fixing merge conflicts
This commit is contained in:
commit
5ee21f26bb
@ -202,6 +202,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/boosted_trees:all_files",
|
"//tensorflow/contrib/boosted_trees:all_files",
|
||||||
"//tensorflow/contrib/boosted_trees/lib:all_files",
|
"//tensorflow/contrib/boosted_trees/lib:all_files",
|
||||||
"//tensorflow/contrib/boosted_trees/proto:all_files",
|
"//tensorflow/contrib/boosted_trees/proto:all_files",
|
||||||
|
"//tensorflow/contrib/boosted_trees/resources:all_files",
|
||||||
"//tensorflow/contrib/cloud:all_files",
|
"//tensorflow/contrib/cloud:all_files",
|
||||||
"//tensorflow/contrib/cloud/kernels:all_files",
|
"//tensorflow/contrib/cloud/kernels:all_files",
|
||||||
"//tensorflow/contrib/compiler:all_files",
|
"//tensorflow/contrib/compiler:all_files",
|
||||||
@ -256,6 +257,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
|
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
|
||||||
"//tensorflow/contrib/training:all_files",
|
"//tensorflow/contrib/training:all_files",
|
||||||
"//tensorflow/contrib/util:all_files",
|
"//tensorflow/contrib/util:all_files",
|
||||||
|
"//tensorflow/contrib/xla_tf_graph:all_files",
|
||||||
"//tensorflow/core:all_files",
|
"//tensorflow/core:all_files",
|
||||||
"//tensorflow/core/debug:all_files",
|
"//tensorflow/core/debug:all_files",
|
||||||
"//tensorflow/core/distributed_runtime:all_files",
|
"//tensorflow/core/distributed_runtime:all_files",
|
||||||
|
@ -51,6 +51,7 @@ genrule(
|
|||||||
"test_graph_tfgather.pb",
|
"test_graph_tfgather.pb",
|
||||||
"test_graph_tfmatmul.pb",
|
"test_graph_tfmatmul.pb",
|
||||||
"test_graph_tfmatmulandadd.pb",
|
"test_graph_tfmatmulandadd.pb",
|
||||||
|
"test_graph_tffunction.pb",
|
||||||
],
|
],
|
||||||
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
|
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
|
||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
@ -114,6 +115,15 @@ tf_library(
|
|||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tffunction",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tffunction.config.pbtxt",
|
||||||
|
cpp_class = "FunctionComp",
|
||||||
|
graph = "test_graph_tffunction.pb",
|
||||||
|
tags = ["manual"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "tfcompile_test",
|
name = "tfcompile_test",
|
||||||
srcs = ["tfcompile_test.cc"],
|
srcs = ["tfcompile_test.cc"],
|
||||||
@ -122,6 +132,7 @@ cc_test(
|
|||||||
":test_graph_tfadd",
|
":test_graph_tfadd",
|
||||||
":test_graph_tfadd_with_ckpt",
|
":test_graph_tfadd_with_ckpt",
|
||||||
":test_graph_tfadd_with_ckpt_saver",
|
":test_graph_tfadd_with_ckpt_saver",
|
||||||
|
":test_graph_tffunction",
|
||||||
":test_graph_tfgather",
|
":test_graph_tfgather",
|
||||||
":test_graph_tfmatmul",
|
":test_graph_tfmatmul",
|
||||||
":test_graph_tfmatmulandadd",
|
":test_graph_tfmatmulandadd",
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -95,6 +96,17 @@ def tfmatmulandadd(_):
|
|||||||
math_ops.add(x, y, name='x_y_sum')
|
math_ops.add(x, y, name='x_y_sum')
|
||||||
|
|
||||||
|
|
||||||
|
def tffunction(_):
|
||||||
|
|
||||||
|
@function.Defun(dtypes.int32, dtypes.int32)
|
||||||
|
def test_func(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
x = constant_op.constant([1], name='x_const')
|
||||||
|
y = constant_op.constant([2], name='y_const')
|
||||||
|
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg
|
||||||
|
|
||||||
|
|
||||||
def write_graph(build_graph, out_dir):
|
def write_graph(build_graph, out_dir):
|
||||||
"""Build a graph using build_graph and write it out."""
|
"""Build a graph using build_graph and write it out."""
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
@ -112,6 +124,7 @@ def main(_):
|
|||||||
write_graph(tfgather, FLAGS.out_dir)
|
write_graph(tfgather, FLAGS.out_dir)
|
||||||
write_graph(tfmatmul, FLAGS.out_dir)
|
write_graph(tfmatmul, FLAGS.out_dir)
|
||||||
write_graph(tfmatmulandadd, FLAGS.out_dir)
|
write_graph(tfmatmulandadd, FLAGS.out_dir)
|
||||||
|
write_graph(tffunction, FLAGS.out_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -121,7 +134,6 @@ if __name__ == '__main__':
|
|||||||
'--out_dir',
|
'--out_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='',
|
default='',
|
||||||
help='Output directory for graphs, checkpoints and savers.'
|
help='Output directory for graphs, checkpoints and savers.')
|
||||||
)
|
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
# Text form of tensorflow.tfcompile.Config proto.
|
||||||
|
feed {
|
||||||
|
id { node_name: "x_const" }
|
||||||
|
shape {
|
||||||
|
dim { size: 1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feed {
|
||||||
|
id { node_name: "y_const" }
|
||||||
|
shape {
|
||||||
|
dim { size: 1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fetch {
|
||||||
|
id { node_name: "func_call" }
|
||||||
|
}
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
|
||||||
@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TFCompileTest, Function) {
|
||||||
|
// The function is equivalent to an addition
|
||||||
|
FunctionComp add_fn;
|
||||||
|
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
|
||||||
|
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
|
||||||
|
|
||||||
|
add_fn.arg0() = 1;
|
||||||
|
add_fn.arg1() = 2;
|
||||||
|
EXPECT_TRUE(add_fn.Run());
|
||||||
|
EXPECT_EQ(add_fn.error_msg(), "");
|
||||||
|
EXPECT_EQ(add_fn.result0(), 3);
|
||||||
|
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||||
|
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we don't recurse infinitely on recursive functions.
|
// Make sure we don't recurse infinitely on recursive functions.
|
||||||
const int kMaxRecursionDepth = 5;
|
const int kMaxRecursionDepth = 10;
|
||||||
|
|
||||||
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
||||||
int depth, FunctionLibraryRuntime* lib_runtime);
|
int depth, FunctionLibraryRuntime* lib_runtime);
|
||||||
|
@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTest, OnesLike) {
|
||||||
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||||
|
ExpectTfAndXlaOutputsAreClose(
|
||||||
|
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -257,6 +257,11 @@ class UnaryOpsTest(XLATestCase):
|
|||||||
np.array([[4, 3], [2, 1]], dtype=dtype),
|
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||||
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
array_ops.ones_like,
|
||||||
|
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||||
|
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
|
||||||
|
|
||||||
def testLogicalOps(self):
|
def testLogicalOps(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
math_ops.logical_not,
|
math_ops.logical_not,
|
||||||
|
@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
|
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
|
||||||
|
|
||||||
|
class OnesLikeOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape input_shape = ctx->InputShape(0);
|
||||||
|
|
||||||
|
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
|
||||||
|
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,
|
|||||||
|
|
||||||
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||||
: options_(std::move(options)),
|
: options_(std::move(options)),
|
||||||
|
initialization_status_(Status::OK()),
|
||||||
next_step_id_(1),
|
next_step_id_(1),
|
||||||
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
|
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
|
||||||
device_mgr_({device_}) {}
|
device_mgr_({device_}) {
|
||||||
|
if (options_.populate_resource_manager) {
|
||||||
|
initialization_status_ =
|
||||||
|
(*options_.populate_resource_manager)(device_->resource_manager());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
XlaCompiler::~XlaCompiler() = default;
|
XlaCompiler::~XlaCompiler() = default;
|
||||||
|
|
||||||
@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
|
|||||||
CompilationResult* result) {
|
CompilationResult* result) {
|
||||||
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
|
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
|
||||||
|
|
||||||
|
// Report the error here if initialization failed.
|
||||||
|
TF_RETURN_IF_ERROR(initialization_status_);
|
||||||
|
|
||||||
xla::ComputationBuilder builder(client(), name);
|
xla::ComputationBuilder builder(client(), name);
|
||||||
XlaContext* context =
|
XlaContext* context =
|
||||||
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
|
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
|
||||||
|
@ -214,6 +214,12 @@ class XlaCompiler {
|
|||||||
// This is useful to prune stateful operators that should not be executed
|
// This is useful to prune stateful operators that should not be executed
|
||||||
// from a function body.
|
// from a function body.
|
||||||
bool prune_unreachable_nodes = false;
|
bool prune_unreachable_nodes = false;
|
||||||
|
|
||||||
|
// If not nullptr, populate_resource_manager is called with the
|
||||||
|
// compilation device's resource manager when the compilation
|
||||||
|
// device is created, and can be used to create metadata objects
|
||||||
|
// that can be accessed by XLA op kernels.
|
||||||
|
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
explicit XlaCompiler(Options options);
|
explicit XlaCompiler(Options options);
|
||||||
@ -247,6 +253,7 @@ class XlaCompiler {
|
|||||||
Status BuildExecutable(const CompilationResult& result,
|
Status BuildExecutable(const CompilationResult& result,
|
||||||
std::unique_ptr<xla::LocalExecutable>* executable);
|
std::unique_ptr<xla::LocalExecutable>* executable);
|
||||||
|
|
||||||
|
const Options& options() const { return options_; }
|
||||||
xla::Client* client() const { return options_.client; }
|
xla::Client* client() const { return options_.client; }
|
||||||
XlaCompilationDevice* device() const { return device_; }
|
XlaCompilationDevice* device() const { return device_; }
|
||||||
const DeviceMgr* device_mgr() const { return &device_mgr_; }
|
const DeviceMgr* device_mgr() const { return &device_mgr_; }
|
||||||
@ -260,6 +267,9 @@ class XlaCompiler {
|
|||||||
private:
|
private:
|
||||||
Options options_;
|
Options options_;
|
||||||
|
|
||||||
|
// Status set to non-OK in the constructor if initialization fails.
|
||||||
|
Status initialization_status_;
|
||||||
|
|
||||||
// Returns the next step sequence number.
|
// Returns the next step sequence number.
|
||||||
int64 NextStepId();
|
int64 NextStepId();
|
||||||
|
|
||||||
|
@ -17,12 +17,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/cc/ops/function_ops.h"
|
#include "tensorflow/cc/ops/function_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
@ -33,6 +35,65 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Helper class to test the ability to pass resources through to XLA
|
||||||
|
// compiled kernels.
|
||||||
|
class DummyResourceForTest : public ResourceBase {
|
||||||
|
public:
|
||||||
|
string DebugString() override { return "dummy"; }
|
||||||
|
void Increment() { ++value_; }
|
||||||
|
int Get() { return value_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int value_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DummyReadResourceOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
|
||||||
|
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
|
||||||
|
DummyResourceForTest* dummy;
|
||||||
|
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
|
||||||
|
rm->default_container(), "dummy", &dummy));
|
||||||
|
dummy->Increment();
|
||||||
|
dummy->Unref();
|
||||||
|
|
||||||
|
ctx->SetOutput(0, ctx->Input(0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class DummyReadResourceCC {
|
||||||
|
public:
|
||||||
|
DummyReadResourceCC(const Scope& scope, const Input& value) {
|
||||||
|
if (!scope.ok()) return;
|
||||||
|
auto _value = ops::AsNodeOut(scope, value);
|
||||||
|
if (!scope.ok()) return;
|
||||||
|
Node* ret;
|
||||||
|
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
|
||||||
|
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
|
||||||
|
scope.UpdateBuilder(&builder);
|
||||||
|
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||||
|
if (!scope.ok()) return;
|
||||||
|
this->output_ = Output(ret, 0);
|
||||||
|
}
|
||||||
|
Node* node() const { return output_.node(); }
|
||||||
|
|
||||||
|
Output output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_OP("DummyReadResource")
|
||||||
|
.Input("input: int32")
|
||||||
|
.Output("output: int32")
|
||||||
|
.Doc(R"doc(
|
||||||
|
A dummy Op.
|
||||||
|
|
||||||
|
input: dummy input.
|
||||||
|
output: dummy output.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
|
||||||
|
|
||||||
class XlaCompilerTest : public ::testing::Test {
|
class XlaCompilerTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests compilation and execution of a graph that adds two tensors.
|
||||||
|
TEST_F(XlaCompilerTest, ResourceManager) {
|
||||||
|
// Builds a graph that calls the dummy resource Op.
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||||
|
auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
|
||||||
|
auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
// Builds a description of the argument.
|
||||||
|
std::vector<XlaCompiler::Argument> args(1);
|
||||||
|
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[0].type = DT_INT32;
|
||||||
|
args[0].shape = TensorShape({2});
|
||||||
|
|
||||||
|
DummyResourceForTest* resource = new DummyResourceForTest();
|
||||||
|
|
||||||
|
// Compiles the graph.
|
||||||
|
auto options = DefaultOptions();
|
||||||
|
std::function<Status(ResourceMgr*)> populate_function =
|
||||||
|
[resource](ResourceMgr* rm) {
|
||||||
|
resource->Ref();
|
||||||
|
return rm->Create(rm->default_container(), "dummy", resource);
|
||||||
|
};
|
||||||
|
options.populate_resource_manager = &populate_function;
|
||||||
|
XlaCompiler compiler(options);
|
||||||
|
auto flr = BuildFunctionLibraryRuntime(compiler);
|
||||||
|
|
||||||
|
EXPECT_EQ(0, resource->Get());
|
||||||
|
|
||||||
|
XlaCompiler::CompilationResult result;
|
||||||
|
TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
|
||||||
|
&result));
|
||||||
|
|
||||||
|
EXPECT_EQ(1, resource->Get());
|
||||||
|
|
||||||
|
resource->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
|
|||||||
XlaContext::Get(context_).AddSideEffects();
|
XlaContext::Get(context_).AddSideEffects();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
|
||||||
|
return XlaContext::Get(context_).compiler()->options();
|
||||||
|
}
|
||||||
|
|
||||||
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
|
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
|
||||||
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
|
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
|
||||||
context_->CtxFailureWithWarning(s);
|
context_->CtxFailureWithWarning(s);
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -182,6 +183,11 @@ class XlaOpKernelContext {
|
|||||||
// Returns the underlying OpKernelContext. Use rarely.
|
// Returns the underlying OpKernelContext. Use rarely.
|
||||||
OpKernelContext* op_kernel_context() const { return context_; }
|
OpKernelContext* op_kernel_context() const { return context_; }
|
||||||
|
|
||||||
|
// Returns the options passed to the XlaCompiler that is being
|
||||||
|
// run. Used for, e.g., While to inherit options needed for nested
|
||||||
|
// computation.
|
||||||
|
const XlaCompiler::Options& GetCompilerOptions() const;
|
||||||
|
|
||||||
// TODO(phawkins): find a better home for these helpers.
|
// TODO(phawkins): find a better home for these helpers.
|
||||||
|
|
||||||
// Get an XLA lambda to compute Max. This is cached in the
|
// Get an XLA lambda to compute Max. This is cached in the
|
||||||
|
@ -167,6 +167,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
!backend.second.op_filter(kdef.get())) {
|
!backend.second.op_filter(kdef.get())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
VLOG(2) << "XLA op registration: device: " << backend.first
|
||||||
|
<< " op: " << op.first;
|
||||||
registry.kernel_registrars_.emplace_back(
|
registry.kernel_registrars_.emplace_back(
|
||||||
new kernel_factory::OpKernelRegistrar(
|
new kernel_factory::OpKernelRegistrar(
|
||||||
new KernelDef(*kdef), "XlaJitOp", op.second->factory));
|
new KernelDef(*kdef), "XlaJitOp", op.second->factory));
|
||||||
|
@ -6,6 +6,7 @@ package_group(
|
|||||||
name = "friends",
|
name = "friends",
|
||||||
packages = [
|
packages = [
|
||||||
"//tensorflow/compiler/...",
|
"//tensorflow/compiler/...",
|
||||||
|
"//tensorflow/contrib/xla_tf_graph/...",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1229,8 +1229,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
|||||||
VLOG(2) << "done with request";
|
VLOG(2) << "done with request";
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
NoteError(s);
|
return s;
|
||||||
return first_error_;
|
|
||||||
}
|
}
|
||||||
return response.is_constant();
|
return response.is_constant();
|
||||||
}
|
}
|
||||||
@ -1255,8 +1254,7 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
|||||||
VLOG(2) << "done with request";
|
VLOG(2) << "done with request";
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
NoteError(s);
|
return s;
|
||||||
return first_error_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RET_CHECK(response.output().handle() != 0);
|
TF_RET_CHECK(response.output().handle() != 0);
|
||||||
|
@ -120,6 +120,7 @@ class HloComputation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const string& name() const { return name_; }
|
const string& name() const { return name_; }
|
||||||
|
void set_name(const string& name) { name_ = name; }
|
||||||
|
|
||||||
// Return a string representation of the computation.
|
// Return a string representation of the computation.
|
||||||
string ToString() const;
|
string ToString() const;
|
||||||
@ -257,7 +258,7 @@ class HloComputation {
|
|||||||
// Internal helper to collect unreachable roots.
|
// Internal helper to collect unreachable roots.
|
||||||
std::vector<HloInstruction*> CollectUnreachableRoots() const;
|
std::vector<HloInstruction*> CollectUnreachableRoots() const;
|
||||||
|
|
||||||
const string name_;
|
string name_;
|
||||||
HloInstruction* root_instruction_;
|
HloInstruction* root_instruction_;
|
||||||
|
|
||||||
// Module containing this computation.
|
// Module containing this computation.
|
||||||
|
@ -357,7 +357,9 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
|
|||||||
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
|
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
|
||||||
// Compute the cost of the fused expression.
|
// Compute the cost of the fused expression.
|
||||||
HloInstruction* fused_expression_root = fusion->fused_expression_root();
|
HloInstruction* fused_expression_root = fusion->fused_expression_root();
|
||||||
HloCostAnalysis visitor(shape_size_);
|
// Don't compute sizes inside of fused ops. We don't use the size here and the
|
||||||
|
// operations inside might not have a layout.
|
||||||
|
HloCostAnalysis visitor([](const Shape&) { return 0; });
|
||||||
TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
|
TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
|
||||||
|
|
||||||
// Attribute the cost of the fused expression to the fusion node.
|
// Attribute the cost of the fused expression to the fusion node.
|
||||||
|
@ -375,6 +375,33 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
|
|||||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(FusionCostAnalysis, NoLayout) {
|
||||||
|
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
|
||||||
|
// Instructions within a fused op may have no layout.
|
||||||
|
Shape shape_without_layout = shape_with_layout;
|
||||||
|
shape_without_layout.clear_layout();
|
||||||
|
|
||||||
|
auto c1 = HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
|
||||||
|
auto c2 =
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
|
||||||
|
|
||||||
|
auto broadcast =
|
||||||
|
HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
|
||||||
|
auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
|
||||||
|
c1.get(), broadcast.get());
|
||||||
|
|
||||||
|
auto fusion = HloInstruction::CreateFusion(
|
||||||
|
shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
|
||||||
|
fusion->FuseInstruction(broadcast.get());
|
||||||
|
|
||||||
|
HloCostAnalysis fusion_analysis(ShapeSize);
|
||||||
|
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
|
||||||
|
|
||||||
|
EXPECT_EQ(fusion_analysis.flop_count(), 120);
|
||||||
|
EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloCostAnalysisTest, TupleCost) {
|
TEST_F(HloCostAnalysisTest, TupleCost) {
|
||||||
HloCostAnalysis analysis(ShapeSize);
|
HloCostAnalysis analysis(ShapeSize);
|
||||||
{
|
{
|
||||||
|
@ -31,20 +31,38 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
HloComputation* HloModule::AddEntryComputation(
|
HloModule::HloModule(const string& name,
|
||||||
|
const VersionedComputationHandle& entry_computation_handle)
|
||||||
|
: name_(name),
|
||||||
|
entry_computation_(nullptr),
|
||||||
|
has_entry_computation_handle_(true),
|
||||||
|
entry_computation_handle_(entry_computation_handle),
|
||||||
|
computation_name_uniquer_(/*separator=*/".") {}
|
||||||
|
|
||||||
|
HloModule::HloModule(const string& name)
|
||||||
|
: name_(name),
|
||||||
|
entry_computation_(nullptr),
|
||||||
|
computation_name_uniquer_(/*separator=*/".") {}
|
||||||
|
|
||||||
|
HloComputation* HloModule::AddComputationInternal(
|
||||||
std::unique_ptr<HloComputation> computation) {
|
std::unique_ptr<HloComputation> computation) {
|
||||||
CHECK_EQ(nullptr, entry_computation_);
|
computation->set_name(
|
||||||
entry_computation_ = computation.get();
|
computation_name_uniquer_.GetUniqueName(computation->name()));
|
||||||
computation->set_parent(this);
|
computation->set_parent(this);
|
||||||
computations_.push_back(std::move(computation));
|
computations_.push_back(std::move(computation));
|
||||||
return computations_.back().get();
|
return computations_.back().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloComputation* HloModule::AddEntryComputation(
|
||||||
|
std::unique_ptr<HloComputation> computation) {
|
||||||
|
CHECK_EQ(nullptr, entry_computation_);
|
||||||
|
entry_computation_ = computation.get();
|
||||||
|
return AddComputationInternal(std::move(computation));
|
||||||
|
}
|
||||||
|
|
||||||
HloComputation* HloModule::AddEmbeddedComputation(
|
HloComputation* HloModule::AddEmbeddedComputation(
|
||||||
std::unique_ptr<HloComputation> computation) {
|
std::unique_ptr<HloComputation> computation) {
|
||||||
computation->set_parent(this);
|
return AddComputationInternal(std::move(computation));
|
||||||
computations_.push_back(std::move(computation));
|
|
||||||
return computations_.back().get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloModule::ReplaceComputations(
|
void HloModule::ReplaceComputations(
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||||
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
|
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
@ -41,19 +42,14 @@ namespace xla {
|
|||||||
// computations are owned by the module.
|
// computations are owned by the module.
|
||||||
class HloModule {
|
class HloModule {
|
||||||
public:
|
public:
|
||||||
explicit HloModule(const string& name,
|
HloModule(const string& name,
|
||||||
const VersionedComputationHandle& entry_computation_handle)
|
const VersionedComputationHandle& entry_computation_handle);
|
||||||
: name_(name),
|
|
||||||
entry_computation_(nullptr),
|
|
||||||
has_entry_computation_handle_(true),
|
|
||||||
entry_computation_handle_(entry_computation_handle) {}
|
|
||||||
|
|
||||||
// Constructor without a versioned computation handle. This constructor should
|
// Constructor without a versioned computation handle. This constructor should
|
||||||
// only be used for HloModules used outside of the XLA service (eg
|
// only be used for HloModules used outside of the XLA service (eg
|
||||||
// tests). The versioned handle is used by the service in the compilation
|
// tests). The versioned handle is used by the service in the compilation
|
||||||
// cache.
|
// cache.
|
||||||
explicit HloModule(const string& name)
|
explicit HloModule(const string& name);
|
||||||
: name_(name), entry_computation_(nullptr) {}
|
|
||||||
|
|
||||||
// Adds an entry computation to the module. A module can only have one entry
|
// Adds an entry computation to the module. A module can only have one entry
|
||||||
// computation. Returns a pointer to the newly added computation.
|
// computation. Returns a pointer to the newly added computation.
|
||||||
@ -111,6 +107,9 @@ class HloModule {
|
|||||||
uint64 RandomNew64() const;
|
uint64 RandomNew64() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
HloComputation* AddComputationInternal(
|
||||||
|
std::unique_ptr<HloComputation> computation);
|
||||||
|
|
||||||
const string name_;
|
const string name_;
|
||||||
HloComputation* entry_computation_;
|
HloComputation* entry_computation_;
|
||||||
std::vector<std::unique_ptr<HloComputation>> computations_;
|
std::vector<std::unique_ptr<HloComputation>> computations_;
|
||||||
@ -125,6 +124,9 @@ class HloModule {
|
|||||||
// Versioned handle of the entry computation of the module.
|
// Versioned handle of the entry computation of the module.
|
||||||
bool has_entry_computation_handle_ = false;
|
bool has_entry_computation_handle_ = false;
|
||||||
VersionedComputationHandle entry_computation_handle_;
|
VersionedComputationHandle entry_computation_handle_;
|
||||||
|
|
||||||
|
// Unique name generator for computation names, which are unique per module.
|
||||||
|
NameUniquer computation_name_uniquer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -74,6 +74,11 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) {
|
|||||||
EXPECT_MATCH(
|
EXPECT_MATCH(
|
||||||
testing::ListToVec<HloComputation*>(module->MakeComputationPostOrder()),
|
testing::ListToVec<HloComputation*>(module->MakeComputationPostOrder()),
|
||||||
testing::UnorderedMatcher<HloComputation*>(computation1, computation2));
|
testing::UnorderedMatcher<HloComputation*>(computation1, computation2));
|
||||||
|
|
||||||
|
// We specified the same name for both computations, but the HloModule should
|
||||||
|
// have made the names unique.
|
||||||
|
EXPECT_EQ(computation1->name(), "Constant");
|
||||||
|
EXPECT_EQ(computation2->name(), "Constant.1");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
|
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
|
||||||
|
@ -633,26 +633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
|
TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
|
||||||
switch (operation) {
|
switch (operation) {
|
||||||
case TRIOP_CLAMP:
|
case TRIOP_CLAMP:
|
||||||
TF_RETURN_IF_ERROR(
|
return InferClampShape(lhs, rhs, ehs);
|
||||||
ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
|
|
||||||
if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
|
|
||||||
(ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
|
|
||||||
return rhs;
|
|
||||||
}
|
|
||||||
if (ShapeUtil::Rank(rhs) == 0) {
|
|
||||||
if (ShapeUtil::Compatible(lhs, ehs)) {
|
|
||||||
return lhs;
|
|
||||||
}
|
|
||||||
return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
|
|
||||||
}
|
|
||||||
return Unimplemented("not yet implemented: %s, %s <clamp> %s",
|
|
||||||
lhs.ShortDebugString().c_str(),
|
|
||||||
ehs.ShortDebugString().c_str(),
|
|
||||||
rhs.ShortDebugString().c_str());
|
|
||||||
case TRIOP_SELECT:
|
case TRIOP_SELECT:
|
||||||
return InferSelectShape(lhs, rhs, ehs);
|
return InferSelectShape(lhs, rhs, ehs);
|
||||||
case TRIOP_UPDATE:
|
case TRIOP_UPDATE:
|
||||||
@ -1332,6 +1313,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
|
||||||
|
// "degenerate" cases, as with binary elementwise ops.
|
||||||
|
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
|
||||||
|
const Shape& min, const Shape& operand, const Shape& max) {
|
||||||
|
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
|
||||||
|
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
|
||||||
|
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
|
||||||
|
if (!ShapeUtil::SameElementType(min, operand) ||
|
||||||
|
!ShapeUtil::SameElementType(max, operand)) {
|
||||||
|
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
|
||||||
|
ShapeUtil::HumanString(min).c_str(),
|
||||||
|
ShapeUtil::HumanString(operand).c_str(),
|
||||||
|
ShapeUtil::HumanString(max).c_str());
|
||||||
|
}
|
||||||
|
if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
|
||||||
|
(ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
|
||||||
|
return operand;
|
||||||
|
}
|
||||||
|
if (ShapeUtil::IsScalar(operand)) {
|
||||||
|
if (ShapeUtil::Compatible(min, max)) {
|
||||||
|
return min;
|
||||||
|
} else if (ShapeUtil::IsScalar(min)) {
|
||||||
|
return max;
|
||||||
|
} else if (ShapeUtil::IsScalar(max)) {
|
||||||
|
return min;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Unimplemented(
|
||||||
|
"not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
|
||||||
|
max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
|
||||||
|
// "degenerate" cases, as with binary elementwise ops, as well as scalar
|
||||||
|
// broadcast from all operands, not just the predicate.
|
||||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||||
const Shape& pred, const Shape& on_true, const Shape& on_false) {
|
const Shape& pred, const Shape& on_true, const Shape& on_false) {
|
||||||
if (!ShapeUtil::Compatible(on_true, on_false)) {
|
if (!ShapeUtil::Compatible(on_true, on_false)) {
|
||||||
|
@ -190,6 +190,10 @@ class ShapeInference {
|
|||||||
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
|
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
|
||||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
|
||||||
|
|
||||||
|
// Helper for inferring the shape of Clamp ops.
|
||||||
|
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
|
||||||
|
const Shape& max);
|
||||||
|
|
||||||
// Helper for inferring the shape of Select ops.
|
// Helper for inferring the shape of Select ops.
|
||||||
static StatusOr<Shape> InferSelectShape(const Shape& pred,
|
static StatusOr<Shape> InferSelectShape(const Shape& pred,
|
||||||
const Shape& on_true,
|
const Shape& on_true,
|
||||||
|
@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
|
|||||||
testing::ContainsRegex("pred operand must have PRED element type"));
|
testing::ContainsRegex("pred operand must have PRED element type"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
|
||||||
|
matrix_64_48_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampAllScalar) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampMinScalar) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
|
||||||
|
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
|
||||||
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, ClampBadShapes) {
|
||||||
|
// Type mismatch
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
|
||||||
|
.ok());
|
||||||
|
// Dimension mismatch
|
||||||
|
ASSERT_FALSE(
|
||||||
|
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||||
|
vector_64_, vector_32_, vector_32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(
|
||||||
|
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||||
|
vector_32_, vector_64_, vector_32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(
|
||||||
|
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||||
|
vector_32_, vector_32_, vector_64_)
|
||||||
|
.ok());
|
||||||
|
// Dimension mismatch, where one operand is a scalar
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
|
||||||
|
.ok());
|
||||||
|
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||||
|
TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
|
||||||
|
.ok());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
||||||
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
||||||
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -245,37 +246,69 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
|
|||||||
ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
|
ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
|
struct DivS32Params {
|
||||||
ComputationBuilder builder(client_, TestName());
|
int32 dividend;
|
||||||
builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
|
int32 divisor;
|
||||||
|
int32 quotient;
|
||||||
|
int32 remainder;
|
||||||
|
};
|
||||||
|
|
||||||
ComputeAndCompareR0<int32>(&builder, -2, {});
|
void PrintTo(const DivS32Params& p, std::ostream* os) {
|
||||||
|
*os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
|
||||||
|
<< p.remainder << "}";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
|
class DivS32Test : public ClientLibraryTestBase,
|
||||||
ComputationBuilder builder(client_, TestName());
|
public ::testing::WithParamInterface<DivS32Params> {};
|
||||||
builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
|
|
||||||
|
|
||||||
ComputeAndCompareR0<int32>(&builder, -1, {});
|
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
|
||||||
|
DivS32Params p = GetParam();
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
builder.Div(builder.ConstantR0<int32>(p.dividend),
|
||||||
|
builder.ConstantR0<int32>(p.divisor));
|
||||||
|
|
||||||
|
ComputeAndCompareR0<int32>(&builder, p.quotient, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
|
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
|
||||||
|
DivS32Params p = GetParam();
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
|
builder.Rem(builder.ConstantR0<int32>(p.dividend),
|
||||||
builder.ConstantR0<int32>(7919));
|
builder.ConstantR0<int32>(p.divisor));
|
||||||
|
|
||||||
ComputeAndCompareR0<int32>(&builder, -1309, {});
|
ComputeAndCompareR0<int32>(&builder, p.remainder, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
|
INSTANTIATE_TEST_CASE_P(
|
||||||
ComputationBuilder builder(client_, TestName());
|
DivS32Test_Instantiation, DivS32Test,
|
||||||
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
|
::testing::Values(
|
||||||
builder.ConstantR0<int32>(INT_MAX));
|
// Positive divisors.
|
||||||
|
DivS32Params{5, 2, 2, 1}, //
|
||||||
|
DivS32Params{-5, 2, -2, -1}, //
|
||||||
|
DivS32Params{17, 3, 5, 2}, //
|
||||||
|
DivS32Params{-17, 3, -5, -2}, //
|
||||||
|
// Negative divisors.
|
||||||
|
DivS32Params{5, -2, -2, 1}, //
|
||||||
|
DivS32Params{-5, -2, 2, -1}, //
|
||||||
|
DivS32Params{17, -3, -5, 2}, //
|
||||||
|
DivS32Params{-17, -3, 5, -2}, //
|
||||||
|
// Large positive divisors.
|
||||||
|
DivS32Params{INT32_MIN, 7919, -271181, -1309}, //
|
||||||
|
DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, //
|
||||||
|
DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, //
|
||||||
|
DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, //
|
||||||
|
DivS32Params{INT32_MIN, 0x40000000, -2, 0}, //
|
||||||
|
DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, //
|
||||||
|
// Large negative divisors.
|
||||||
|
DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, //
|
||||||
|
DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, //
|
||||||
|
DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, //
|
||||||
|
DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, //
|
||||||
|
DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, //
|
||||||
|
DivS32Params{INT32_MIN, -0x40000000, 2, 0}, //
|
||||||
|
DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
|
||||||
|
|
||||||
ComputeAndCompareR0<int32>(&builder, -1, {});
|
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
|
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
|
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
|
||||||
builder.Rem(x, builder.ConstantR0<int32>(80000));
|
builder.Rem(x, builder.ConstantR0<int32>(80000));
|
||||||
|
@ -7,8 +7,6 @@ exports_files(["LICENSE"])
|
|||||||
|
|
||||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "contrib_py",
|
name = "contrib_py",
|
||||||
srcs = glob(["**/*.py"]),
|
srcs = glob(["**/*.py"]),
|
||||||
@ -46,6 +44,7 @@ py_library(
|
|||||||
"//tensorflow/contrib/losses:losses_py",
|
"//tensorflow/contrib/losses:losses_py",
|
||||||
"//tensorflow/contrib/memory_stats:memory_stats_py",
|
"//tensorflow/contrib/memory_stats:memory_stats_py",
|
||||||
"//tensorflow/contrib/metrics:metrics_py",
|
"//tensorflow/contrib/metrics:metrics_py",
|
||||||
|
"//tensorflow/contrib/nccl:nccl_py",
|
||||||
"//tensorflow/contrib/ndlstm",
|
"//tensorflow/contrib/ndlstm",
|
||||||
"//tensorflow/contrib/nn:nn_py",
|
"//tensorflow/contrib/nn:nn_py",
|
||||||
"//tensorflow/contrib/opt:opt_py",
|
"//tensorflow/contrib/opt:opt_py",
|
||||||
@ -65,9 +64,7 @@ py_library(
|
|||||||
"//tensorflow/contrib/tfprof",
|
"//tensorflow/contrib/tfprof",
|
||||||
"//tensorflow/contrib/training:training_py",
|
"//tensorflow/contrib/training:training_py",
|
||||||
"//tensorflow/contrib/util:util_py",
|
"//tensorflow/contrib/util:util_py",
|
||||||
] + if_not_windows([
|
],
|
||||||
"//tensorflow/contrib/nccl:nccl_py",
|
|
||||||
]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.contrib import image
|
|||||||
from tensorflow.contrib import input_pipeline
|
from tensorflow.contrib import input_pipeline
|
||||||
from tensorflow.contrib import integrate
|
from tensorflow.contrib import integrate
|
||||||
from tensorflow.contrib import keras
|
from tensorflow.contrib import keras
|
||||||
|
from tensorflow.contrib import kernel_methods
|
||||||
from tensorflow.contrib import labeled_tensor
|
from tensorflow.contrib import labeled_tensor
|
||||||
from tensorflow.contrib import layers
|
from tensorflow.contrib import layers
|
||||||
from tensorflow.contrib import learn
|
from tensorflow.contrib import learn
|
||||||
@ -45,6 +46,7 @@ from tensorflow.contrib import lookup
|
|||||||
from tensorflow.contrib import losses
|
from tensorflow.contrib import losses
|
||||||
from tensorflow.contrib import memory_stats
|
from tensorflow.contrib import memory_stats
|
||||||
from tensorflow.contrib import metrics
|
from tensorflow.contrib import metrics
|
||||||
|
from tensorflow.contrib import nccl
|
||||||
from tensorflow.contrib import nn
|
from tensorflow.contrib import nn
|
||||||
from tensorflow.contrib import opt
|
from tensorflow.contrib import opt
|
||||||
from tensorflow.contrib import quantization
|
from tensorflow.contrib import quantization
|
||||||
|
@ -160,3 +160,90 @@ cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "models",
|
||||||
|
srcs = ["models/multiple_additive_trees.cc"],
|
||||||
|
hdrs = ["models/multiple_additive_trees.h"],
|
||||||
|
deps = [
|
||||||
|
":trees",
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "multiple_additive_trees_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["models/multiple_additive_trees_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":batch_features_testutil",
|
||||||
|
":models",
|
||||||
|
":random_tree_gen",
|
||||||
|
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:tensor_testutil",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "trees",
|
||||||
|
srcs = ["trees/decision_tree.cc"],
|
||||||
|
hdrs = ["trees/decision_tree.h"],
|
||||||
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "trees_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["trees/decision_tree_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":trees",
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/core:tensor_testutil",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "batch_features_testutil",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["testutil/batch_features_testutil.cc"],
|
||||||
|
hdrs = ["testutil/batch_features_testutil.h"],
|
||||||
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "random_tree_gen",
|
||||||
|
srcs = ["testutil/random_tree_gen.cc"],
|
||||||
|
hdrs = ["testutil/random_tree_gen.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "random_tree_gen_main",
|
||||||
|
srcs = ["testutil/random_tree_gen_main.cc"],
|
||||||
|
deps = [
|
||||||
|
":random_tree_gen",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,140 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace models {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void CalculateTreesToKeep(
|
||||||
|
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||||
|
const std::vector<int32>& trees_to_drop, const int32 num_trees,
|
||||||
|
const bool only_finalized, std::vector<int32>* trees_to_keep) {
|
||||||
|
trees_to_keep->reserve(num_trees - trees_to_drop.size());
|
||||||
|
|
||||||
|
int32 index = 0;
|
||||||
|
// This assumes that trees_to_drop is a sorted list of tree ids.
|
||||||
|
for (int32 tree = 0; tree < num_trees; ++tree) {
|
||||||
|
if ((!trees_to_drop.empty() && index < trees_to_drop.size() &&
|
||||||
|
trees_to_drop[index] == tree) ||
|
||||||
|
(only_finalized && config.tree_metadata_size() > 0 &&
|
||||||
|
!config.tree_metadata(tree).is_finalized())) {
|
||||||
|
++index;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
trees_to_keep->push_back(tree);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdatePredictions(
|
||||||
|
const int32 index_1, const int32 index_2, const float value,
|
||||||
|
tensorflow::TTypes<float>::Matrix* output_predictions,
|
||||||
|
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
|
||||||
|
(*output_predictions)(index_1, index_2) += value;
|
||||||
|
|
||||||
|
if (additional_output_predictions != nullptr) {
|
||||||
|
(*additional_output_predictions)(index_1, index_2) += value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdatePredictionsBasedOnTree(
|
||||||
|
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||||
|
const int32 tree_idx, const boosted_trees::utils::Example& example,
|
||||||
|
tensorflow::TTypes<float>::Matrix* output_predictions,
|
||||||
|
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
|
||||||
|
const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx);
|
||||||
|
const float tree_weight = config.tree_weights(tree_idx);
|
||||||
|
const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
|
||||||
|
QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
|
||||||
|
const auto& leaf_node = tree.nodes(leaf_idx);
|
||||||
|
QCHECK(leaf_node.has_leaf())
|
||||||
|
<< "Invalid leaf node: " << leaf_node.DebugString();
|
||||||
|
if (leaf_node.leaf().has_sparse_vector()) {
|
||||||
|
const auto& leaf = leaf_node.leaf().sparse_vector();
|
||||||
|
QCHECK_EQ(leaf.index_size(), leaf.value_size());
|
||||||
|
for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) {
|
||||||
|
const float value = tree_weight * leaf.value(class_idx);
|
||||||
|
|
||||||
|
UpdatePredictions(example.example_idx, leaf.index(class_idx), value,
|
||||||
|
output_predictions, additional_output_predictions);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
|
||||||
|
const auto& leaf = leaf_node.leaf().vector();
|
||||||
|
for (size_t i = 0; i < leaf.value_size(); ++i) {
|
||||||
|
const float value = tree_weight * leaf.value(i);
|
||||||
|
UpdatePredictions(example.example_idx, i, value, output_predictions,
|
||||||
|
additional_output_predictions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void MultipleAdditiveTrees::Predict(
|
||||||
|
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||||
|
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
|
||||||
|
const boosted_trees::utils::BatchFeatures& features,
|
||||||
|
tensorflow::thread::ThreadPool* worker_threads,
|
||||||
|
tensorflow::TTypes<float>::Matrix output_predictions,
|
||||||
|
tensorflow::TTypes<float>::Matrix no_dropout_predictions) {
|
||||||
|
// Zero out predictions as the model is additive.
|
||||||
|
output_predictions.setZero();
|
||||||
|
no_dropout_predictions.setZero();
|
||||||
|
|
||||||
|
// Get batch size.
|
||||||
|
const int64 batch_size = features.batch_size();
|
||||||
|
if (batch_size <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the list of trees to keep.
|
||||||
|
std::vector<int32> trees_to_keep;
|
||||||
|
CalculateTreesToKeep(config, trees_to_drop, config.trees_size(),
|
||||||
|
only_finalized_trees, &trees_to_keep);
|
||||||
|
|
||||||
|
// Lambda for doing a block of work.
|
||||||
|
auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop,
|
||||||
|
&output_predictions,
|
||||||
|
&no_dropout_predictions](int64 start, int64 end) {
|
||||||
|
auto examples_iterable = features.examples_iterable(start, end);
|
||||||
|
for (const auto& example : examples_iterable) {
|
||||||
|
for (const int32 tree_idx : trees_to_keep) {
|
||||||
|
UpdatePredictionsBasedOnTree(config, tree_idx, example,
|
||||||
|
&output_predictions,
|
||||||
|
&no_dropout_predictions);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now do predictions for dropped trees
|
||||||
|
for (const int32 tree_idx : trees_to_drop) {
|
||||||
|
UpdatePredictionsBasedOnTree(config, tree_idx, example,
|
||||||
|
&no_dropout_predictions, nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(salehay): parallelize this for low latency in serving path where
|
||||||
|
// batch size tends to be small but ensemble size tends to be large.
|
||||||
|
boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
|
||||||
|
worker_threads, update_predictions);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace models
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,50 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace models {
|
||||||
|
|
||||||
|
// Multiple additive trees prediction model.
|
||||||
|
// This class does not hold state and is thread safe.
|
||||||
|
class MultipleAdditiveTrees {
|
||||||
|
public:
|
||||||
|
// Predict runs tree ensemble on the given batch and updates
|
||||||
|
// output predictions accordingly. The method also returns predictions that
|
||||||
|
// we would get if no dropout was applied.
|
||||||
|
static void Predict(
|
||||||
|
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||||
|
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
|
||||||
|
const boosted_trees::utils::BatchFeatures& features,
|
||||||
|
thread::ThreadPool* const thread_pool,
|
||||||
|
TTypes<float>::Matrix output_predictions,
|
||||||
|
TTypes<float>::Matrix no_dropout_predictions);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace models
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
@ -0,0 +1,381 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
using boosted_trees::trees::DecisionTreeEnsembleConfig;
|
||||||
|
using test::AsTensor;
|
||||||
|
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace models {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
const int32 kNumThreadsMultiThreaded = 6;
|
||||||
|
const int32 kNumThreadsSingleThreaded = 1;
|
||||||
|
|
||||||
|
class MultipleAdditiveTreesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
MultipleAdditiveTreesTest() : batch_features_(2) {
|
||||||
|
// Create a batch of two examples having one dense feature each.
|
||||||
|
// The shape of the dense matrix is therefore 2x1 as in one row per example
|
||||||
|
// and one column per feature per example.
|
||||||
|
auto dense_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
batch_features_.Initialize({dense_matrix}, {}, {}, {}, {}, {}, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
boosted_trees::utils::BatchFeatures batch_features_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MultipleAdditiveTreesTest, Empty) {
|
||||||
|
// Create empty tree ensemble.
|
||||||
|
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||||
|
auto output_tensor = AsTensor<float>({9.0f, 23.0f}, {2, 1});
|
||||||
|
auto output_matrix = output_tensor.matrix<float>();
|
||||||
|
auto no_dropout_output_matrix = output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
// Predict for both instances.
|
||||||
|
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||||
|
kNumThreadsSingleThreaded);
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
EXPECT_EQ(0, output_matrix(0, 0));
|
||||||
|
EXPECT_EQ(0, output_matrix(1, 0));
|
||||||
|
|
||||||
|
// There was no dropout
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MultipleAdditiveTreesTest, SingleClass) {
|
||||||
|
// Add one bias and one stump to ensemble for a single class.
|
||||||
|
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||||
|
auto* tree1 = tree_ensemble_config.add_trees();
|
||||||
|
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
bias_leaf->add_index(0);
|
||||||
|
bias_leaf->add_value(-0.4f);
|
||||||
|
auto* tree2 = tree_ensemble_config.add_trees();
|
||||||
|
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||||
|
dense_split->set_feature_column(0);
|
||||||
|
dense_split->set_threshold(5.0f);
|
||||||
|
dense_split->set_left_id(1);
|
||||||
|
dense_split->set_right_id(2);
|
||||||
|
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
leaf1->add_index(0);
|
||||||
|
leaf1->add_value(0.9f);
|
||||||
|
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
leaf2->add_index(0);
|
||||||
|
leaf2->add_value(0.2f);
|
||||||
|
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
|
||||||
|
auto output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
|
||||||
|
auto output_matrix = output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
auto no_dropout_output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
|
||||||
|
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||||
|
kNumThreadsSingleThreaded);
|
||||||
|
|
||||||
|
// Normal case.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||||
|
|
||||||
|
// No dropout predictions are the same.
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Weighted case
|
||||||
|
{
|
||||||
|
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
|
||||||
|
weighted.set_tree_weights(0, 6.0);
|
||||||
|
weighted.set_tree_weights(1, 3.2);
|
||||||
|
MultipleAdditiveTrees::Predict(weighted,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
// -0.4 (bias) + 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
|
||||||
|
// -0.4 (bias) + 0.9 (leaf 1).
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0));
|
||||||
|
|
||||||
|
// No dropout predictions are the same.
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Drop first tree.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{0}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1).
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||||
|
}
|
||||||
|
// Drop second tree.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{1}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias).
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||||
|
}
|
||||||
|
// Drop all trees.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{0, 1}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
|
||||||
|
EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MultipleAdditiveTreesTest, MultiClass) {
|
||||||
|
// Add one bias and one stump to ensemble for two classes.
|
||||||
|
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||||
|
auto* tree1 = tree_ensemble_config.add_trees();
|
||||||
|
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
bias_leaf->add_index(0);
|
||||||
|
bias_leaf->add_value(-0.4f);
|
||||||
|
bias_leaf->add_index(1);
|
||||||
|
bias_leaf->add_value(-0.7f);
|
||||||
|
auto* tree2 = tree_ensemble_config.add_trees();
|
||||||
|
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||||
|
dense_split->set_feature_column(0);
|
||||||
|
dense_split->set_threshold(5.0f);
|
||||||
|
dense_split->set_left_id(1);
|
||||||
|
dense_split->set_right_id(2);
|
||||||
|
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
leaf1->add_index(0);
|
||||||
|
leaf1->add_value(0.9f);
|
||||||
|
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||||
|
leaf2->add_index(1);
|
||||||
|
leaf2->add_value(0.2f);
|
||||||
|
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
|
||||||
|
// Predict for both instances.
|
||||||
|
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||||
|
kNumThreadsSingleThreaded);
|
||||||
|
auto output_tensor = AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
|
||||||
|
auto output_matrix = output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
auto no_dropout_output_tensor =
|
||||||
|
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
|
||||||
|
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
// Normal case.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
|
||||||
|
|
||||||
|
// No dropout predictions are the same.
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 2; ++j) {
|
||||||
|
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Weighted case.
|
||||||
|
{
|
||||||
|
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
|
||||||
|
weighted.set_tree_weights(0, 6.0);
|
||||||
|
weighted.set_tree_weights(1, 3.2);
|
||||||
|
MultipleAdditiveTrees::Predict(weighted,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
// bias
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
|
||||||
|
// bias + leaf 2
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f * 6 + 0.2f * 3.2, output_matrix(0, 1));
|
||||||
|
// bias + leaf 2
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9f * 3.2f, output_matrix(1, 0));
|
||||||
|
// bias
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f * 6, output_matrix(1, 1));
|
||||||
|
}
|
||||||
|
// Dropout first tree.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{0}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
|
||||||
|
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||||
|
}
|
||||||
|
// Dropout second tree.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{1}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||||
|
}
|
||||||
|
// Drop both trees.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{0, 1}, batch_features_, &threads,
|
||||||
|
output_matrix, no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
|
||||||
|
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
|
||||||
|
|
||||||
|
// No dropout predictions
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(
|
||||||
|
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
|
||||||
|
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||||
|
auto* tree1 = tree_ensemble_config.add_trees();
|
||||||
|
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_vector();
|
||||||
|
bias_leaf->add_value(-0.4f);
|
||||||
|
bias_leaf->add_value(-0.7f);
|
||||||
|
bias_leaf->add_value(3.0f);
|
||||||
|
auto* tree2 = tree_ensemble_config.add_trees();
|
||||||
|
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||||
|
dense_split->set_feature_column(0);
|
||||||
|
dense_split->set_threshold(5.0f);
|
||||||
|
dense_split->set_left_id(1);
|
||||||
|
dense_split->set_right_id(2);
|
||||||
|
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
|
||||||
|
leaf1->add_value(0.9f);
|
||||||
|
leaf1->add_value(0.8f);
|
||||||
|
leaf1->add_value(0.7f);
|
||||||
|
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
|
||||||
|
leaf2->add_value(0.2f);
|
||||||
|
leaf2->add_value(0.3f);
|
||||||
|
leaf2->add_value(0.4f);
|
||||||
|
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
tree_ensemble_config.add_tree_weights(1.0);
|
||||||
|
|
||||||
|
// Predict for both instances.
|
||||||
|
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||||
|
kNumThreadsSingleThreaded);
|
||||||
|
auto output_tensor =
|
||||||
|
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
|
||||||
|
auto output_matrix = output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
auto no_dropout_output_tensor =
|
||||||
|
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
|
||||||
|
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||||
|
|
||||||
|
// Normal case.
|
||||||
|
{
|
||||||
|
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||||
|
false, // include non-finalized trees
|
||||||
|
{}, batch_features_, &threads, output_matrix,
|
||||||
|
no_dropout_output_matrix);
|
||||||
|
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2)
|
||||||
|
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1)
|
||||||
|
EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1)
|
||||||
|
EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1)
|
||||||
|
|
||||||
|
// No dropout predictions are the same.
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace models
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,88 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace testutil {
|
||||||
|
|
||||||
|
using tensorflow::Tensor;
|
||||||
|
|
||||||
|
void RandomlyInitializeBatchFeatures(
|
||||||
|
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
|
||||||
|
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
|
||||||
|
boosted_trees::utils::BatchFeatures* batch_features) {
|
||||||
|
const int64 batch_size = static_cast<int64>(batch_features->batch_size());
|
||||||
|
|
||||||
|
// Populate dense features.
|
||||||
|
std::vector<tensorflow::Tensor> dense_float_features_list;
|
||||||
|
for (int i = 0; i < num_dense_float_features; ++i) {
|
||||||
|
std::vector<float> values;
|
||||||
|
for (int64 j = 0; j < batch_size; ++j) {
|
||||||
|
values.push_back(rng->RandFloat());
|
||||||
|
}
|
||||||
|
auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1});
|
||||||
|
tensorflow::test::FillValues<float>(&dense_tensor, values);
|
||||||
|
dense_float_features_list.push_back(dense_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate sparse features.
|
||||||
|
std::vector<tensorflow::Tensor> sparse_float_feature_indices_list;
|
||||||
|
std::vector<tensorflow::Tensor> sparse_float_feature_values_list;
|
||||||
|
std::vector<tensorflow::Tensor> sparse_float_feature_shapes_list;
|
||||||
|
for (int i = 0; i < num_sparse_float_features; ++i) {
|
||||||
|
std::set<uint64> indices;
|
||||||
|
const double sparsity =
|
||||||
|
sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo);
|
||||||
|
const double density = 1 - sparsity;
|
||||||
|
for (int64 k = 0; k < static_cast<int64>(density * batch_size) + 1; ++k) {
|
||||||
|
indices.insert(rng->Uniform64(batch_size));
|
||||||
|
}
|
||||||
|
const int64 sparse_values_size = indices.size();
|
||||||
|
std::vector<int64> indices_vector;
|
||||||
|
for (auto idx : indices) {
|
||||||
|
indices_vector.push_back(idx);
|
||||||
|
indices_vector.push_back(0);
|
||||||
|
}
|
||||||
|
auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2});
|
||||||
|
tensorflow::test::FillValues<int64>(&indices_tensor, indices_vector);
|
||||||
|
sparse_float_feature_indices_list.push_back(indices_tensor);
|
||||||
|
|
||||||
|
std::vector<float> values;
|
||||||
|
for (int64 j = 0; j < sparse_values_size; ++j) {
|
||||||
|
values.push_back(rng->RandFloat());
|
||||||
|
}
|
||||||
|
auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size});
|
||||||
|
tensorflow::test::FillValues<float>(&values_tensor, values);
|
||||||
|
sparse_float_feature_values_list.push_back(values_tensor);
|
||||||
|
|
||||||
|
auto shape_tensor = Tensor(tensorflow::DT_INT64, {2});
|
||||||
|
tensorflow::test::FillValues<int64>(&shape_tensor, {batch_size, 1});
|
||||||
|
sparse_float_feature_shapes_list.push_back(shape_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(salehay): Add categorical feature generation support.
|
||||||
|
TF_EXPECT_OK(batch_features->Initialize(
|
||||||
|
dense_float_features_list, sparse_float_feature_indices_list,
|
||||||
|
sparse_float_feature_values_list, sparse_float_feature_shapes_list, {},
|
||||||
|
{}, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace testutil
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace testutil {
|
||||||
|
|
||||||
|
// This method calls Initialize on the given 'batch_features', which will be
|
||||||
|
// populated with randomly generated feature values when the call returns.
|
||||||
|
// 'tensors' returns a vector of all tensors used in the initialization,
|
||||||
|
// because they must outlive 'batch_features'.
|
||||||
|
//
|
||||||
|
// All float features will be either missing or uniformly randomly chosen
|
||||||
|
// from [0, 1). For sparse (float) features, a sparsity is uniformly randomly
|
||||||
|
// chosen from ['sparsity_lo', 'sparsity_hi') per feature, and each instance
|
||||||
|
// will have a probability of sparsity of missing that feature, in other words,
|
||||||
|
// sparsity = 1 - density.
|
||||||
|
void RandomlyInitializeBatchFeatures(
|
||||||
|
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
|
||||||
|
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
|
||||||
|
boosted_trees::utils::BatchFeatures* batch_features);
|
||||||
|
|
||||||
|
} // namespace testutil
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
211
tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
Normal file
211
tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace testutil {
|
||||||
|
|
||||||
|
using tensorflow::boosted_trees::trees::DecisionTreeConfig;
|
||||||
|
using tensorflow::boosted_trees::trees::TreeNode;
|
||||||
|
using boosted_trees::trees::DenseFloatBinarySplit;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Append the given nodes to tree with transfer of pointer ownership.
|
||||||
|
// nodes will not be usable upon return.
|
||||||
|
template <typename T>
|
||||||
|
void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
|
||||||
|
std::reverse(nodes->pointer_begin(), nodes->pointer_end());
|
||||||
|
while (!nodes->empty()) {
|
||||||
|
tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseFloatBinarySplit* GetSplit(TreeNode* node) {
|
||||||
|
switch (node->node_case()) {
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
|
||||||
|
return node->mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultRight:
|
||||||
|
return node->mutable_sparse_float_binary_split_default_right()
|
||||||
|
->mutable_split();
|
||||||
|
case TreeNode::kDenseFloatBinarySplit:
|
||||||
|
return node->mutable_dense_float_binary_split();
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown node type encountered.";
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
|
||||||
|
int dense_feature_size, int sparse_feature_size)
|
||||||
|
: rng_(rng),
|
||||||
|
dense_feature_size_(dense_feature_size),
|
||||||
|
sparse_feature_size_(sparse_feature_size) {}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void AddWeightAndMetadata(
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
|
||||||
|
// Assign the weight of the tree to 1 and say that this weight was updated
|
||||||
|
// only once.
|
||||||
|
ret->add_tree_weights(1.0);
|
||||||
|
auto* meta = ret->add_tree_metadata();
|
||||||
|
meta->set_num_tree_weight_updates(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig
|
||||||
|
RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig ret;
|
||||||
|
*(ret.add_trees()) = Generate(depth);
|
||||||
|
AddWeightAndMetadata(&ret);
|
||||||
|
for (int i = 1; i < tree_count; ++i) {
|
||||||
|
*(ret.add_trees()) = Generate(ret.trees(0));
|
||||||
|
AddWeightAndMetadata(&ret);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
|
||||||
|
DecisionTreeConfig ret = tree;
|
||||||
|
for (auto& node : *ret.mutable_nodes()) {
|
||||||
|
if (node.node_case() == TreeNode::kLeaf) {
|
||||||
|
node.mutable_leaf()->mutable_sparse_vector()->set_value(
|
||||||
|
0, rng_->RandFloat());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Original node is a split. Re-generate it's type but retain the split node
|
||||||
|
// indices.
|
||||||
|
DenseFloatBinarySplit* split = GetSplit(&node);
|
||||||
|
const int left_id = split->left_id();
|
||||||
|
const int right_id = split->right_id();
|
||||||
|
GenerateSplit(&node, left_id, right_id);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
DecisionTreeConfig RandomTreeGen::Generate(int depth) {
|
||||||
|
DecisionTreeConfig ret;
|
||||||
|
// Add root,
|
||||||
|
TreeNode* node = ret.add_nodes();
|
||||||
|
GenerateSplit(node, 1, 2);
|
||||||
|
if (depth == 1) {
|
||||||
|
// Add left and right leaves.
|
||||||
|
TreeNode* left = ret.add_nodes();
|
||||||
|
left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
|
||||||
|
left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
|
||||||
|
TreeNode* right = ret.add_nodes();
|
||||||
|
right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
|
||||||
|
right->mutable_leaf()->mutable_sparse_vector()->add_value(
|
||||||
|
rng_->RandFloat());
|
||||||
|
return ret;
|
||||||
|
} else {
|
||||||
|
DecisionTreeConfig left_branch = Generate(depth - 1);
|
||||||
|
DecisionTreeConfig right_branch = Generate(depth - 1);
|
||||||
|
Combine(&ret, &left_branch, &right_branch);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandomTreeGen::Combine(DecisionTreeConfig* root,
|
||||||
|
DecisionTreeConfig* left_branch,
|
||||||
|
DecisionTreeConfig* right_branch) {
|
||||||
|
const int left_branch_size = left_branch->nodes_size();
|
||||||
|
CHECK_EQ(1, root->nodes_size());
|
||||||
|
// left_branch starts its index at 1. right_branch starts its index at
|
||||||
|
// (left_branch_size + 1).
|
||||||
|
auto* root_node = root->mutable_nodes(0);
|
||||||
|
DenseFloatBinarySplit* root_split = GetSplit(root_node);
|
||||||
|
root_split->set_left_id(1);
|
||||||
|
root_split->set_right_id(left_branch_size + 1);
|
||||||
|
// Shift left/right branch's indices internally so that everything is
|
||||||
|
// consistent.
|
||||||
|
ShiftNodeIndex(left_branch, 1);
|
||||||
|
ShiftNodeIndex(right_branch, left_branch_size + 1);
|
||||||
|
|
||||||
|
// Complexity O(branch node size). No proto copying though.
|
||||||
|
AppendNodes(root, left_branch->mutable_nodes());
|
||||||
|
AppendNodes(root, right_branch->mutable_nodes());
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
|
||||||
|
for (TreeNode& node : *(tree->mutable_nodes())) {
|
||||||
|
DenseFloatBinarySplit* split = nullptr;
|
||||||
|
switch (node.node_case()) {
|
||||||
|
case TreeNode::kLeaf:
|
||||||
|
break;
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
|
||||||
|
split = node.mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
break;
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultRight:
|
||||||
|
split = node.mutable_sparse_float_binary_split_default_right()
|
||||||
|
->mutable_split();
|
||||||
|
break;
|
||||||
|
case TreeNode::kDenseFloatBinarySplit:
|
||||||
|
split = node.mutable_dense_float_binary_split();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown node type encountered.";
|
||||||
|
}
|
||||||
|
if (split) {
|
||||||
|
split->set_left_id(shift + split->left_id());
|
||||||
|
split->set_right_id(shift + split->right_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
|
||||||
|
const double denseSplitProb =
|
||||||
|
sparse_feature_size_ == 0
|
||||||
|
? 1.0
|
||||||
|
: static_cast<double>(dense_feature_size_) /
|
||||||
|
(dense_feature_size_ + sparse_feature_size_);
|
||||||
|
// Generate the tree such that it has equal probability of going left and
|
||||||
|
// right when the feature is missing.
|
||||||
|
static constexpr float kLeftProb = 0.5;
|
||||||
|
|
||||||
|
DenseFloatBinarySplit* split;
|
||||||
|
int feature_size;
|
||||||
|
if (rng_->RandFloat() < denseSplitProb) {
|
||||||
|
feature_size = dense_feature_size_;
|
||||||
|
split = node->mutable_dense_float_binary_split();
|
||||||
|
} else {
|
||||||
|
feature_size = sparse_feature_size_;
|
||||||
|
if (rng_->RandFloat() < kLeftProb) {
|
||||||
|
split = node->mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
} else {
|
||||||
|
split = node->mutable_sparse_float_binary_split_default_right()
|
||||||
|
->mutable_split();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
split->set_threshold(rng_->RandFloat());
|
||||||
|
split->set_feature_column(rng_->Uniform(feature_size));
|
||||||
|
split->set_left_id(left_id);
|
||||||
|
split->set_right_id(right_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace testutil
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,75 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace testutil {
|
||||||
|
|
||||||
|
// Randomly generate a balanced tree, for performance benchmarking purposes,
|
||||||
|
// that assume all features are sparse float features, for now.
|
||||||
|
class RandomTreeGen {
|
||||||
|
public:
|
||||||
|
RandomTreeGen(tensorflow::random::SimplePhilox* rng, int dense_feature_size,
|
||||||
|
int sparse_feature_size);
|
||||||
|
|
||||||
|
// Required: depth must be >= 1.
|
||||||
|
// If one wants to generate multiple trees with the same depth, see also the
|
||||||
|
// overload below.
|
||||||
|
boosted_trees::trees::DecisionTreeConfig Generate(int depth);
|
||||||
|
|
||||||
|
// Randomly generate a new tree with the same depth (and tree structure)
|
||||||
|
// as the given tree. This is faster.
|
||||||
|
boosted_trees::trees::DecisionTreeConfig Generate(
|
||||||
|
const boosted_trees::trees::DecisionTreeConfig& tree);
|
||||||
|
|
||||||
|
// Requried: depth >= 1; tree_count >= 1.
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig GenerateEnsemble(
|
||||||
|
int dept, int tree_count);
|
||||||
|
|
||||||
|
private:
|
||||||
|
tensorflow::random::SimplePhilox* rng_;
|
||||||
|
const int dense_feature_size_;
|
||||||
|
const int sparse_feature_size_;
|
||||||
|
|
||||||
|
// Put together a deeper tree by combining two trees.
|
||||||
|
void Combine(boosted_trees::trees::DecisionTreeConfig* root,
|
||||||
|
boosted_trees::trees::DecisionTreeConfig* left_branch,
|
||||||
|
boosted_trees::trees::DecisionTreeConfig* right_branch);
|
||||||
|
|
||||||
|
// For each node in the provided tree, shift its referenced left/right index
|
||||||
|
// by shift.
|
||||||
|
void ShiftNodeIndex(boosted_trees::trees::DecisionTreeConfig* tree,
|
||||||
|
int shift);
|
||||||
|
|
||||||
|
// Generate a sparse split in the node.
|
||||||
|
void GenerateSplit(boosted_trees::trees::TreeNode* node, int left_id,
|
||||||
|
int right_id);
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(RandomTreeGen);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace testutil
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
@ -0,0 +1,67 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
// Randomly generate a tree ensemble and write to file.
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
|
using tensorflow::Flag;
|
||||||
|
using tensorflow::Flags;
|
||||||
|
using tensorflow::int32;
|
||||||
|
using tensorflow::string;
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
int32 dense_feature_size = 100;
|
||||||
|
int32 sparse_feature_size = 100;
|
||||||
|
int32 depth = 8;
|
||||||
|
int32 tree_count = 10;
|
||||||
|
string filename = "/tmp/trees.pb";
|
||||||
|
std::vector<Flag> flag_list = {
|
||||||
|
Flag("dense_feature_size", &dense_feature_size, "dense feature size"),
|
||||||
|
Flag("sparse_feature_size", &sparse_feature_size, "sparse_feature_size"),
|
||||||
|
Flag("depth", &depth, "tree depth"),
|
||||||
|
Flag("tree_count", &tree_count, "tree count"),
|
||||||
|
Flag("filename", &filename, "Output filename."),
|
||||||
|
};
|
||||||
|
string usage = Flags::Usage(argv[0], flag_list);
|
||||||
|
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
||||||
|
// We need to call this to set up global state for TensorFlow.
|
||||||
|
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||||
|
if (!parse_result) {
|
||||||
|
LOG(ERROR) << "\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::random::PhiloxRandom philox(1);
|
||||||
|
tensorflow::random::SimplePhilox rng(&philox);
|
||||||
|
tensorflow::boosted_trees::testutil::RandomTreeGen tree_gen(
|
||||||
|
&rng, dense_feature_size, sparse_feature_size);
|
||||||
|
const auto& trees = tree_gen.GenerateEnsemble(depth, tree_count);
|
||||||
|
tensorflow::Status status =
|
||||||
|
tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, trees);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(WARNING) << "Failed to write: " << filename << " : " << status;
|
||||||
|
} else {
|
||||||
|
LOG(INFO) << "Tree ensemble written to: " << filename;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
170
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
Normal file
170
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace trees {
|
||||||
|
|
||||||
|
constexpr int kInvalidLeaf = -1;
|
||||||
|
int DecisionTree::Traverse(const DecisionTreeConfig& config,
|
||||||
|
const int32 sub_root_id,
|
||||||
|
const utils::Example& example) {
|
||||||
|
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
|
||||||
|
return kInvalidLeaf;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse tree starting at the provided sub-root.
|
||||||
|
int32 node_id = sub_root_id;
|
||||||
|
while (true) {
|
||||||
|
const auto& current_node = config.nodes(node_id);
|
||||||
|
switch (current_node.node_case()) {
|
||||||
|
case TreeNode::kLeaf: {
|
||||||
|
return node_id;
|
||||||
|
}
|
||||||
|
case TreeNode::kDenseFloatBinarySplit: {
|
||||||
|
const auto& split = current_node.dense_float_binary_split();
|
||||||
|
node_id = example.dense_float_features[split.feature_column()] <=
|
||||||
|
split.threshold()
|
||||||
|
? split.left_id()
|
||||||
|
: split.right_id();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||||
|
const auto& split =
|
||||||
|
current_node.sparse_float_binary_split_default_left().split();
|
||||||
|
auto sparse_feature =
|
||||||
|
example.sparse_float_features[split.feature_column()];
|
||||||
|
node_id = !sparse_feature.has_value() ||
|
||||||
|
sparse_feature.get_value() <= split.threshold()
|
||||||
|
? split.left_id()
|
||||||
|
: split.right_id();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||||
|
const auto& split =
|
||||||
|
current_node.sparse_float_binary_split_default_right().split();
|
||||||
|
auto sparse_feature =
|
||||||
|
example.sparse_float_features[split.feature_column()];
|
||||||
|
node_id = sparse_feature.has_value() &&
|
||||||
|
sparse_feature.get_value() <= split.threshold()
|
||||||
|
? split.left_id()
|
||||||
|
: split.right_id();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kCategoricalIdBinarySplit: {
|
||||||
|
const auto& split = current_node.categorical_id_binary_split();
|
||||||
|
node_id = example.sparse_int_features[split.feature_column()].count(
|
||||||
|
split.feature_id()) > 0
|
||||||
|
? split.left_id()
|
||||||
|
: split.right_id();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::NODE_NOT_SET: {
|
||||||
|
QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
|
||||||
|
<< current_node.DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecisionTree::LinkChildren(const std::vector<int32>& children,
|
||||||
|
TreeNode* parent_node) {
|
||||||
|
// Decide how to link children depending on the parent node's type.
|
||||||
|
auto children_it = children.begin();
|
||||||
|
switch (parent_node->node_case()) {
|
||||||
|
case TreeNode::kLeaf: {
|
||||||
|
// Essentially no-op.
|
||||||
|
QCHECK(children.empty()) << "A leaf node cannot have children.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kDenseFloatBinarySplit: {
|
||||||
|
QCHECK(children.size() == 2)
|
||||||
|
<< "A binary split node must have exactly two children.";
|
||||||
|
auto* split = parent_node->mutable_dense_float_binary_split();
|
||||||
|
split->set_left_id(*children_it);
|
||||||
|
split->set_right_id(*++children_it);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||||
|
QCHECK(children.size() == 2)
|
||||||
|
<< "A binary split node must have exactly two children.";
|
||||||
|
auto* split =
|
||||||
|
parent_node->mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
split->set_left_id(*children_it);
|
||||||
|
split->set_right_id(*++children_it);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||||
|
QCHECK(children.size() == 2)
|
||||||
|
<< "A binary split node must have exactly two children.";
|
||||||
|
auto* split =
|
||||||
|
parent_node->mutable_sparse_float_binary_split_default_right()
|
||||||
|
->mutable_split();
|
||||||
|
split->set_left_id(*children_it);
|
||||||
|
split->set_right_id(*++children_it);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::kCategoricalIdBinarySplit: {
|
||||||
|
QCHECK(children.size() == 2)
|
||||||
|
<< "A binary split node must have exactly two children.";
|
||||||
|
auto* split = parent_node->mutable_categorical_id_binary_split();
|
||||||
|
split->set_left_id(*children_it);
|
||||||
|
split->set_right_id(*++children_it);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TreeNode::NODE_NOT_SET: {
|
||||||
|
QCHECK(false) << "A non-set node cannot have children.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
|
||||||
|
// A node's children depend on its type.
|
||||||
|
switch (node.node_case()) {
|
||||||
|
case TreeNode::kLeaf: {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
case TreeNode::kDenseFloatBinarySplit: {
|
||||||
|
const auto& split = node.dense_float_binary_split();
|
||||||
|
return {split.left_id(), split.right_id()};
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||||
|
const auto& split = node.sparse_float_binary_split_default_left().split();
|
||||||
|
return {split.left_id(), split.right_id()};
|
||||||
|
}
|
||||||
|
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||||
|
const auto& split =
|
||||||
|
node.sparse_float_binary_split_default_right().split();
|
||||||
|
return {split.left_id(), split.right_id()};
|
||||||
|
}
|
||||||
|
case TreeNode::kCategoricalIdBinarySplit: {
|
||||||
|
const auto& split = node.categorical_id_binary_split();
|
||||||
|
return {split.left_id(), split.right_id()};
|
||||||
|
}
|
||||||
|
case TreeNode::NODE_NOT_SET: {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace trees
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
49
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
Normal file
49
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace trees {
|
||||||
|
|
||||||
|
// Decision tree class to encapsulate tree traversal and mutation logic.
|
||||||
|
// This class does not hold state and is thread safe.
|
||||||
|
class DecisionTree {
|
||||||
|
public:
|
||||||
|
// Traverse given an instance, a sub-root and its set of features
|
||||||
|
// and return the leaf index or -1 if the tree is empty or
|
||||||
|
// the sub-root is invalid.
|
||||||
|
static int Traverse(const DecisionTreeConfig& config, int32 sub_root_id,
|
||||||
|
const utils::Example& example);
|
||||||
|
|
||||||
|
// Links the specified children to the parent, the children must
|
||||||
|
// already be added to the decision tree config so this method
|
||||||
|
// just ensures nodes are re-linked.
|
||||||
|
static void LinkChildren(const std::vector<int32>& children,
|
||||||
|
TreeNode* parent_node);
|
||||||
|
|
||||||
|
// Retrieves node children indices if any.
|
||||||
|
static std::vector<int32> GetChildren(const TreeNode& node);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace trees
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
326
tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
Normal file
326
tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace trees {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class DecisionTreeTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
DecisionTreeTest() : batch_features_(2) {
|
||||||
|
// Create a batch of two examples having one dense float, two sparse float
|
||||||
|
// and one sparse int features.
|
||||||
|
// The first example is missing the second sparse feature column and the
|
||||||
|
// second example is missing the first sparse feature column.
|
||||||
|
// This looks like the following:
|
||||||
|
// Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 |
|
||||||
|
// 0 | 7 | -3 | | 3 |
|
||||||
|
// 1 | -2 | | 4 | |
|
||||||
|
auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
|
||||||
|
auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
|
||||||
|
auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
|
||||||
|
auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
|
||||||
|
auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
|
||||||
|
auto sparse_float_values2 = test::AsTensor<float>({4.0f});
|
||||||
|
auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
|
||||||
|
auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
|
||||||
|
auto sparse_int_values1 = test::AsTensor<int64>({3});
|
||||||
|
auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
|
||||||
|
TF_EXPECT_OK(batch_features_.Initialize(
|
||||||
|
{dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2},
|
||||||
|
{sparse_float_values1, sparse_float_values2},
|
||||||
|
{sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1},
|
||||||
|
{sparse_int_values1}, {sparse_int_shape1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SplitType>
|
||||||
|
void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
|
||||||
|
// Verify children were linked.
|
||||||
|
DecisionTree::LinkChildren({3, 8}, node);
|
||||||
|
EXPECT_EQ(3, split->left_id());
|
||||||
|
EXPECT_EQ(8, split->right_id());
|
||||||
|
|
||||||
|
// Invalid cases.
|
||||||
|
EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
|
||||||
|
"A binary split node must have exactly two children.");
|
||||||
|
EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
|
||||||
|
"A binary split node must have exactly two children.");
|
||||||
|
EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
|
||||||
|
"A binary split node must have exactly two children.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetChildren(const TreeNode& node,
|
||||||
|
const std::vector<uint32>& expected_children) {
|
||||||
|
// Verify children were linked.
|
||||||
|
auto children = DecisionTree::GetChildren(node);
|
||||||
|
EXPECT_EQ(children.size(), expected_children.size());
|
||||||
|
for (size_t idx = 0; idx < children.size(); ++idx) {
|
||||||
|
EXPECT_EQ(children[idx], expected_children[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
utils::BatchFeatures batch_features_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseEmpty) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||||
|
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseBias) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||||
|
EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||||
|
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
auto* split_node =
|
||||||
|
tree_config.add_nodes()->mutable_dense_float_binary_split();
|
||||||
|
split_node->set_feature_column(0);
|
||||||
|
split_node->set_threshold(0.0f);
|
||||||
|
split_node->set_left_id(1);
|
||||||
|
split_node->set_right_id(2);
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||||
|
|
||||||
|
// Expect right child to be picked as !(7 <= 0);
|
||||||
|
auto example_it = example_iterable.begin();
|
||||||
|
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||||
|
|
||||||
|
// Expect left child to be picked as (-2 <= 0);
|
||||||
|
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
|
||||||
|
// Test first sparse feature which is missing for the second example.
|
||||||
|
DecisionTreeConfig tree_config1;
|
||||||
|
auto* split_node1 = tree_config1.add_nodes()
|
||||||
|
->mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
split_node1->set_feature_column(0);
|
||||||
|
split_node1->set_threshold(-20.0f);
|
||||||
|
split_node1->set_left_id(1);
|
||||||
|
split_node1->set_right_id(2);
|
||||||
|
tree_config1.add_nodes()->mutable_leaf();
|
||||||
|
tree_config1.add_nodes()->mutable_leaf();
|
||||||
|
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||||
|
|
||||||
|
// Expect right child to be picked as !(-3 <= -20).
|
||||||
|
auto example_it = example_iterable.begin();
|
||||||
|
EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it));
|
||||||
|
|
||||||
|
// Expect left child to be picked as default direction.
|
||||||
|
EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it));
|
||||||
|
|
||||||
|
// Test second sparse feature which is missing for the first example.
|
||||||
|
DecisionTreeConfig tree_config2;
|
||||||
|
auto* split_node2 = tree_config2.add_nodes()
|
||||||
|
->mutable_sparse_float_binary_split_default_right()
|
||||||
|
->mutable_split();
|
||||||
|
split_node2->set_feature_column(1);
|
||||||
|
split_node2->set_threshold(4.0f);
|
||||||
|
split_node2->set_left_id(1);
|
||||||
|
split_node2->set_right_id(2);
|
||||||
|
tree_config2.add_nodes()->mutable_leaf();
|
||||||
|
tree_config2.add_nodes()->mutable_leaf();
|
||||||
|
|
||||||
|
// Expect right child to be picked as default direction.
|
||||||
|
example_it = example_iterable.begin();
|
||||||
|
EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it));
|
||||||
|
|
||||||
|
// Expect left child to be picked as (4 <= 4).
|
||||||
|
EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
auto* split_node =
|
||||||
|
tree_config.add_nodes()->mutable_categorical_id_binary_split();
|
||||||
|
split_node->set_feature_column(0);
|
||||||
|
split_node->set_feature_id(3);
|
||||||
|
split_node->set_left_id(1);
|
||||||
|
split_node->set_right_id(2);
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||||
|
|
||||||
|
// Expect left child to be picked as 3 == 3;
|
||||||
|
auto example_it = example_iterable.begin();
|
||||||
|
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||||
|
|
||||||
|
// Expect right child to be picked as the feature is missing;
|
||||||
|
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, TraverseHybridSplits) {
|
||||||
|
DecisionTreeConfig tree_config;
|
||||||
|
auto* split_node1 =
|
||||||
|
tree_config.add_nodes()->mutable_dense_float_binary_split();
|
||||||
|
split_node1->set_feature_column(0);
|
||||||
|
split_node1->set_threshold(9.0f);
|
||||||
|
split_node1->set_left_id(1); // sparse split.
|
||||||
|
split_node1->set_right_id(2); // leaf
|
||||||
|
auto* split_node2 = tree_config.add_nodes()
|
||||||
|
->mutable_sparse_float_binary_split_default_left()
|
||||||
|
->mutable_split();
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
split_node2->set_feature_column(0);
|
||||||
|
split_node2->set_threshold(-20.0f);
|
||||||
|
split_node2->set_left_id(3);
|
||||||
|
split_node2->set_right_id(4);
|
||||||
|
auto* split_node3 =
|
||||||
|
tree_config.add_nodes()->mutable_categorical_id_binary_split();
|
||||||
|
split_node3->set_feature_column(0);
|
||||||
|
split_node3->set_feature_id(2);
|
||||||
|
split_node3->set_left_id(5);
|
||||||
|
split_node3->set_right_id(6);
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
tree_config.add_nodes()->mutable_leaf();
|
||||||
|
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||||
|
|
||||||
|
// Expect will go left through the first dense split as (7.0f <= 9.0f),
|
||||||
|
// then will go right through the sparse split as !(-3 <= -20).
|
||||||
|
auto example_it = example_iterable.begin();
|
||||||
|
EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||||
|
|
||||||
|
// Expect will go left through the first dense split as (-2.0f <= 9.0f),
|
||||||
|
// then will go left the default direction as the sparse feature is missing,
|
||||||
|
// then will go right as 2 != 3 on the categorical split.
|
||||||
|
EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
|
||||||
|
// Create leaf node.
|
||||||
|
TreeNode node;
|
||||||
|
node.mutable_leaf();
|
||||||
|
|
||||||
|
// No-op.
|
||||||
|
DecisionTree::LinkChildren({}, &node);
|
||||||
|
|
||||||
|
// Invalid case.
|
||||||
|
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
|
||||||
|
"A leaf node cannot have children.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split = node.mutable_dense_float_binary_split();
|
||||||
|
split->set_left_id(-1);
|
||||||
|
split->set_right_id(-1);
|
||||||
|
TestLinkChildrenBinary(&node, split);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split =
|
||||||
|
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
|
||||||
|
split->set_left_id(-1);
|
||||||
|
split->set_right_id(-1);
|
||||||
|
TestLinkChildrenBinary(&node, split);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split =
|
||||||
|
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
|
||||||
|
split->set_left_id(-1);
|
||||||
|
split->set_right_id(-1);
|
||||||
|
TestLinkChildrenBinary(&node, split);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split = node.mutable_categorical_id_binary_split();
|
||||||
|
split->set_left_id(-1);
|
||||||
|
split->set_right_id(-1);
|
||||||
|
TestLinkChildrenBinary(&node, split);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
|
||||||
|
// Create unset node.
|
||||||
|
TreeNode node;
|
||||||
|
|
||||||
|
// Invalid case.
|
||||||
|
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
|
||||||
|
"A non-set node cannot have children.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenLeaf) {
|
||||||
|
TreeNode node;
|
||||||
|
node.mutable_leaf();
|
||||||
|
TestGetChildren(node, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split = node.mutable_dense_float_binary_split();
|
||||||
|
split->set_left_id(23);
|
||||||
|
split->set_right_id(24);
|
||||||
|
TestGetChildren(node, {23, 24});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split =
|
||||||
|
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
|
||||||
|
split->set_left_id(12);
|
||||||
|
split->set_right_id(13);
|
||||||
|
TestGetChildren(node, {12, 13});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split =
|
||||||
|
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
|
||||||
|
split->set_left_id(1);
|
||||||
|
split->set_right_id(2);
|
||||||
|
TestGetChildren(node, {1, 2});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
|
||||||
|
TreeNode node;
|
||||||
|
auto* split = node.mutable_categorical_id_binary_split();
|
||||||
|
split->set_left_id(7);
|
||||||
|
split->set_right_id(8);
|
||||||
|
TestGetChildren(node, {7, 8});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
|
||||||
|
TreeNode node;
|
||||||
|
TestGetChildren(node, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace trees
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
@ -24,6 +24,15 @@ tf_proto_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "quantiles_proto",
|
||||||
|
srcs = [
|
||||||
|
"quantiles.proto",
|
||||||
|
],
|
||||||
|
cc_api_version = 2,
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
tf_proto_library(
|
tf_proto_library(
|
||||||
name = "tree_config_proto",
|
name = "tree_config_proto",
|
||||||
srcs = ["tree_config.proto"],
|
srcs = ["tree_config.proto"],
|
||||||
|
32
tensorflow/contrib/boosted_trees/proto/quantiles.proto
Normal file
32
tensorflow/contrib/boosted_trees/proto/quantiles.proto
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
|
package boosted_trees;
|
||||||
|
|
||||||
|
message QuantileConfig {
|
||||||
|
// Maximum eps error when computing quantile summaries.
|
||||||
|
double eps = 1;
|
||||||
|
// Number of quantiles to generate.
|
||||||
|
int64 num_quantiles = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantileEntry {
|
||||||
|
// Value for the entry.
|
||||||
|
float value = 1;
|
||||||
|
// Weight for the entry.
|
||||||
|
float weight = 2;
|
||||||
|
// We need the minimum and maximum rank possible for this entry.
|
||||||
|
// Rank is 0.0 for the absolute minimum and sum of the weights for the maximum
|
||||||
|
// value in the input.
|
||||||
|
float min_rank = 3;
|
||||||
|
float max_rank = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantileSummaryState {
|
||||||
|
repeated QuantileEntry entries = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantileStreamState {
|
||||||
|
repeated QuantileSummaryState summaries = 1;
|
||||||
|
}
|
53
tensorflow/contrib/boosted_trees/resources/BUILD
Normal file
53
tensorflow/contrib/boosted_trees/resources/BUILD
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//tensorflow/contrib/boosted_trees:__subpackages__",
|
||||||
|
"//tensorflow/contrib/boosted_trees:friends",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "stamped_resource",
|
||||||
|
hdrs = ["stamped_resource.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "quantile_stream_resource",
|
||||||
|
hdrs = ["quantile_stream_resource.h"],
|
||||||
|
deps = [
|
||||||
|
":stamped_resource",
|
||||||
|
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
|
||||||
|
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "decision_tree_ensemble_resource",
|
||||||
|
hdrs = ["decision_tree_ensemble_resource.h"],
|
||||||
|
deps = [
|
||||||
|
":stamped_resource",
|
||||||
|
"//tensorflow/contrib/boosted_trees/lib:trees",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
@ -0,0 +1,77 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
namespace models {
|
||||||
|
|
||||||
|
// Keep a tree ensemble in memory for efficient evaluation and mutation.
|
||||||
|
class DecisionTreeEnsembleResource : public StampedResource {
|
||||||
|
public:
|
||||||
|
// Constructor.
|
||||||
|
explicit DecisionTreeEnsembleResource()
|
||||||
|
: decision_tree_ensemble_(
|
||||||
|
protobuf::Arena::CreateMessage<
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
|
||||||
|
|
||||||
|
string DebugString() override {
|
||||||
|
return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
|
||||||
|
decision_tree_ensemble_->trees_size(), "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
const boosted_trees::trees::DecisionTreeEnsembleConfig&
|
||||||
|
decision_tree_ensemble() const {
|
||||||
|
return *decision_tree_ensemble_;
|
||||||
|
}
|
||||||
|
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig*
|
||||||
|
mutable_decision_tree_ensemble() {
|
||||||
|
return decision_tree_ensemble_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets the resource and frees the protos in arena.
|
||||||
|
// Caller needs to hold the mutex lock while calling this.
|
||||||
|
void Reset() {
|
||||||
|
// Reset stamp.
|
||||||
|
set_stamp(-1);
|
||||||
|
|
||||||
|
// Clear tree ensemle.
|
||||||
|
arena_.Reset();
|
||||||
|
CHECK_EQ(0, arena_.SpaceAllocated());
|
||||||
|
decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
|
||||||
|
}
|
||||||
|
|
||||||
|
mutex* get_mutex() { return &mu_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
protobuf::Arena arena_;
|
||||||
|
mutex mu_;
|
||||||
|
boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace models
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
@ -0,0 +1,104 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT
|
||||||
|
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
|
||||||
|
using QuantileStream =
|
||||||
|
boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
|
||||||
|
|
||||||
|
// Resource for accumulating summaries for multiple columns.
|
||||||
|
class QuantileStreamResource : public StampedResource {
|
||||||
|
public:
|
||||||
|
QuantileStreamResource(const float epsilon, const int32 num_quantiles,
|
||||||
|
const int64 max_elements, int64 stamp_token)
|
||||||
|
: stream_(epsilon, max_elements),
|
||||||
|
are_buckets_ready_(false),
|
||||||
|
epsilon_(epsilon),
|
||||||
|
num_quantiles_(num_quantiles),
|
||||||
|
max_elements_(max_elements) {
|
||||||
|
set_stamp(stamp_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString() override { return "QuantileStreamResource"; }
|
||||||
|
|
||||||
|
tensorflow::mutex* mutex() { return &mu_; }
|
||||||
|
|
||||||
|
QuantileStream* stream(int64 stamp) {
|
||||||
|
CHECK(is_stamp_valid(stamp));
|
||||||
|
return &stream_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<float>& boundaries(int64 stamp) {
|
||||||
|
CHECK(is_stamp_valid(stamp));
|
||||||
|
return boundaries_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_boundaries(int64 stamp, const std::vector<float>& boundaries) {
|
||||||
|
CHECK(is_stamp_valid(stamp));
|
||||||
|
are_buckets_ready_ = true;
|
||||||
|
boundaries_ = boundaries;
|
||||||
|
}
|
||||||
|
|
||||||
|
float epsilon() const { return epsilon_; }
|
||||||
|
int32 num_quantiles() const { return num_quantiles_; }
|
||||||
|
|
||||||
|
void Reset(int64 stamp) {
|
||||||
|
set_stamp(stamp);
|
||||||
|
stream_ = QuantileStream(epsilon_, max_elements_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool are_buckets_ready() const { return are_buckets_ready_; }
|
||||||
|
void set_buckets_ready(bool are_buckets_ready) {
|
||||||
|
are_buckets_ready_ = are_buckets_ready;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
~QuantileStreamResource() override {}
|
||||||
|
|
||||||
|
// Mutex for the whole resource.
|
||||||
|
tensorflow::mutex mu_;
|
||||||
|
|
||||||
|
// Quantile stream.
|
||||||
|
QuantileStream stream_;
|
||||||
|
|
||||||
|
// Stores the boundaries from the previous iteration. Empty during the first
|
||||||
|
// iteration.
|
||||||
|
std::vector<float> boundaries_;
|
||||||
|
|
||||||
|
// Whether boundaries are created. Initially boundaries are empty until
|
||||||
|
// set_boundaries are called.
|
||||||
|
bool are_buckets_ready_;
|
||||||
|
|
||||||
|
const float epsilon_;
|
||||||
|
const int32 num_quantiles_;
|
||||||
|
// An upper-bound for the number of elements.
|
||||||
|
int64 max_elements_;
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace boosted_trees {
|
||||||
|
|
||||||
|
// A StampedResource is a resource that has a stamp token associated with it.
|
||||||
|
// Before reading from or applying updates to the resource, the stamp should
|
||||||
|
// be checked to verify that the update is not stale.
|
||||||
|
class StampedResource : public ResourceBase {
|
||||||
|
public:
|
||||||
|
StampedResource() : stamp_(-1) {}
|
||||||
|
|
||||||
|
bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
|
||||||
|
|
||||||
|
int64 stamp() const { return stamp_; }
|
||||||
|
void set_stamp(int64 stamp) { stamp_ = stamp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 stamp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace boosted_trees
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
|||||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
|
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
|
||||||
|
@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i
|
|||||||
GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc")
|
||||||
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
|
||||||
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
|
||||||
|
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
|
||||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
|
||||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
|
||||||
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
|
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
|
||||||
|
@ -111,6 +111,7 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
|
|||||||
"${tensorflow_source_dir}/tensorflow/core/*.proto"
|
"${tensorflow_source_dir}/tensorflow/core/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
||||||
)
|
)
|
||||||
@ -124,6 +125,7 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
|
|||||||
file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
|
file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
|
||||||
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
||||||
)
|
)
|
||||||
@ -342,6 +344,9 @@ add_python_module("tensorflow/contrib/keras/python/keras/layers")
|
|||||||
add_python_module("tensorflow/contrib/keras/python/keras/preprocessing")
|
add_python_module("tensorflow/contrib/keras/python/keras/preprocessing")
|
||||||
add_python_module("tensorflow/contrib/keras/python/keras/utils")
|
add_python_module("tensorflow/contrib/keras/python/keras/utils")
|
||||||
add_python_module("tensorflow/contrib/keras/python/keras/wrappers")
|
add_python_module("tensorflow/contrib/keras/python/keras/wrappers")
|
||||||
|
add_python_module("tensorflow/contrib/kernel_methods")
|
||||||
|
add_python_module("tensorflow/contrib/kernel_methods/python")
|
||||||
|
add_python_module("tensorflow/contrib/kernel_methods/python/mappers")
|
||||||
add_python_module("tensorflow/contrib/labeled_tensor")
|
add_python_module("tensorflow/contrib/labeled_tensor")
|
||||||
add_python_module("tensorflow/contrib/labeled_tensor/python")
|
add_python_module("tensorflow/contrib/labeled_tensor/python")
|
||||||
add_python_module("tensorflow/contrib/labeled_tensor/python/ops")
|
add_python_module("tensorflow/contrib/labeled_tensor/python/ops")
|
||||||
@ -405,6 +410,11 @@ add_python_module("tensorflow/contrib/ndlstm/python")
|
|||||||
add_python_module("tensorflow/contrib/nn")
|
add_python_module("tensorflow/contrib/nn")
|
||||||
add_python_module("tensorflow/contrib/nn/python")
|
add_python_module("tensorflow/contrib/nn/python")
|
||||||
add_python_module("tensorflow/contrib/nn/python/ops")
|
add_python_module("tensorflow/contrib/nn/python/ops")
|
||||||
|
add_python_module("tensorflow/contrib/nccl")
|
||||||
|
add_python_module("tensorflow/contrib/nccl/kernels")
|
||||||
|
add_python_module("tensorflow/contrib/nccl/ops")
|
||||||
|
add_python_module("tensorflow/contrib/nccl/python")
|
||||||
|
add_python_module("tensorflow/contrib/nccl/python/ops")
|
||||||
add_python_module("tensorflow/contrib/opt")
|
add_python_module("tensorflow/contrib/opt")
|
||||||
add_python_module("tensorflow/contrib/opt/python")
|
add_python_module("tensorflow/contrib/opt/python")
|
||||||
add_python_module("tensorflow/contrib/opt/python/training")
|
add_python_module("tensorflow/contrib/opt/python/training")
|
||||||
@ -599,6 +609,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops"
|
|||||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py)
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py)
|
||||||
GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops"
|
GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops"
|
||||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py)
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py)
|
||||||
|
GENERATE_PYTHON_OP_LIB("contrib_nccl_ops"
|
||||||
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py)
|
||||||
GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
|
GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
|
||||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
|
||||||
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"
|
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"
|
||||||
|
@ -69,19 +69,25 @@ class BinomialTest(test.TestCase):
|
|||||||
self.assertEqual((1, 3), binom.logits.get_shape())
|
self.assertEqual((1, 3), binom.logits.get_shape())
|
||||||
self.assertAllClose(logits, binom.logits.eval())
|
self.assertAllClose(logits, binom.logits.eval())
|
||||||
|
|
||||||
def testPmfNandCountsAgree(self):
|
def testPmfAndCdfNandCountsAgree(self):
|
||||||
p = [[0.1, 0.2, 0.7]]
|
p = [[0.1, 0.2, 0.7]]
|
||||||
n = [[5.]]
|
n = [[5.]]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
||||||
binom.prob([2., 3, 2]).eval()
|
binom.prob([2., 3, 2]).eval()
|
||||||
binom.prob([3., 1, 2]).eval()
|
binom.prob([3., 1, 2]).eval()
|
||||||
|
binom.cdf([2., 3, 2]).eval()
|
||||||
|
binom.cdf([3., 1, 2]).eval()
|
||||||
with self.assertRaisesOpError("Condition x >= 0.*"):
|
with self.assertRaisesOpError("Condition x >= 0.*"):
|
||||||
binom.prob([-1., 4, 2]).eval()
|
binom.prob([-1., 4, 2]).eval()
|
||||||
with self.assertRaisesOpError("Condition x <= y.*"):
|
with self.assertRaisesOpError("Condition x <= y.*"):
|
||||||
binom.prob([7., 3, 0]).eval()
|
binom.prob([7., 3, 0]).eval()
|
||||||
|
with self.assertRaisesOpError("Condition x >= 0.*"):
|
||||||
|
binom.cdf([-1., 4, 2]).eval()
|
||||||
|
with self.assertRaisesOpError("Condition x <= y.*"):
|
||||||
|
binom.cdf([7., 3, 0]).eval()
|
||||||
|
|
||||||
def testPmfNonIntegerCounts(self):
|
def testPmfAndCdfNonIntegerCounts(self):
|
||||||
p = [[0.1, 0.2, 0.7]]
|
p = [[0.1, 0.2, 0.7]]
|
||||||
n = [[5.]]
|
n = [[5.]]
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -89,50 +95,72 @@ class BinomialTest(test.TestCase):
|
|||||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
||||||
binom.prob([2., 3, 2]).eval()
|
binom.prob([2., 3, 2]).eval()
|
||||||
binom.prob([3., 1, 2]).eval()
|
binom.prob([3., 1, 2]).eval()
|
||||||
|
binom.cdf([2., 3, 2]).eval()
|
||||||
|
binom.cdf([3., 1, 2]).eval()
|
||||||
# Both equality and integer checking fail.
|
# Both equality and integer checking fail.
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
"cannot contain fractional components."):
|
"cannot contain fractional components."):
|
||||||
binom.prob([1.0, 2.5, 1.5]).eval()
|
binom.prob([1.0, 2.5, 1.5]).eval()
|
||||||
|
with self.assertRaisesOpError(
|
||||||
|
"cannot contain fractional components."):
|
||||||
|
binom.cdf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=False)
|
binom = binomial.Binomial(total_count=n, probs=p, validate_args=False)
|
||||||
binom.prob([1., 2., 3.]).eval()
|
binom.prob([1., 2., 3.]).eval()
|
||||||
|
binom.cdf([1., 2., 3.]).eval()
|
||||||
# Non-integer arguments work.
|
# Non-integer arguments work.
|
||||||
binom.prob([1.0, 2.5, 1.5]).eval()
|
binom.prob([1.0, 2.5, 1.5]).eval()
|
||||||
|
binom.cdf([1.0, 2.5, 1.5]).eval()
|
||||||
|
|
||||||
def testPmfBothZeroBatches(self):
|
def testPmfAndCdfBothZeroBatches(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Both zero-batches. No broadcast
|
# Both zero-batches. No broadcast
|
||||||
p = 0.5
|
p = 0.5
|
||||||
counts = 1.
|
counts = 1.
|
||||||
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
|
binom = binomial.Binomial(total_count=1., probs=p)
|
||||||
|
pmf = binom.prob(counts)
|
||||||
|
cdf = binom.cdf(counts)
|
||||||
self.assertAllClose(0.5, pmf.eval())
|
self.assertAllClose(0.5, pmf.eval())
|
||||||
|
self.assertAllClose(stats.binom.cdf(counts, n=1, p=p), cdf.eval())
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
self.assertEqual((), cdf.get_shape())
|
||||||
|
|
||||||
def testPmfBothZeroBatchesNontrivialN(self):
|
def testPmfAndCdfBothZeroBatchesNontrivialN(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Both zero-batches. No broadcast
|
# Both zero-batches. No broadcast
|
||||||
p = 0.1
|
p = 0.1
|
||||||
counts = 3.
|
counts = 3.
|
||||||
binom = binomial.Binomial(total_count=5., probs=p)
|
binom = binomial.Binomial(total_count=5., probs=p)
|
||||||
pmf = binom.prob(counts)
|
pmf = binom.prob(counts)
|
||||||
|
cdf = binom.cdf(counts)
|
||||||
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
|
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
|
||||||
|
self.assertAllClose(stats.binom.cdf(counts, n=5., p=p), cdf.eval())
|
||||||
self.assertEqual((), pmf.get_shape())
|
self.assertEqual((), pmf.get_shape())
|
||||||
|
self.assertEqual((), cdf.get_shape())
|
||||||
|
|
||||||
def testPmfPStretchedInBroadcastWhenSameRank(self):
|
def testPmfAndCdfPStretchedInBroadcastWhenSameRank(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
p = [[0.1, 0.9]]
|
p = [[0.1, 0.9]]
|
||||||
counts = [[1., 2.]]
|
counts = [[1., 2.]]
|
||||||
pmf = binomial.Binomial(total_count=3., probs=p).prob(counts)
|
binom = binomial.Binomial(total_count=3., probs=p)
|
||||||
|
pmf = binom.prob(counts)
|
||||||
|
cdf = binom.cdf(counts)
|
||||||
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
|
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
|
||||||
|
self.assertAllClose(stats.binom.cdf(counts, n=3., p=p), cdf.eval())
|
||||||
self.assertEqual((1, 2), pmf.get_shape())
|
self.assertEqual((1, 2), pmf.get_shape())
|
||||||
|
self.assertEqual((1, 2), cdf.get_shape())
|
||||||
|
|
||||||
def testPmfPStretchedInBroadcastWhenLowerRank(self):
|
def testPmfAndCdfPStretchedInBroadcastWhenLowerRank(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
p = [0.1, 0.4]
|
p = [0.1, 0.4]
|
||||||
counts = [[1.], [0.]]
|
counts = [[1.], [0.]]
|
||||||
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
|
binom = binomial.Binomial(total_count=1., probs=p)
|
||||||
|
pmf = binom.prob(counts)
|
||||||
|
cdf = binom.cdf(counts)
|
||||||
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
|
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
|
||||||
|
self.assertAllClose([[1.0, 1.0], [0.9, 0.6]], cdf.eval())
|
||||||
self.assertEqual((2, 2), pmf.get_shape())
|
self.assertEqual((2, 2), pmf.get_shape())
|
||||||
|
self.assertEqual((2, 2), cdf.get_shape())
|
||||||
|
|
||||||
def testBinomialMean(self):
|
def testBinomialMean(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -103,6 +103,31 @@ class MultivariateNormalDiagTest(test.TestCase):
|
|||||||
self.assertAllClose(cov_mat, np.cov(samps.T),
|
self.assertAllClose(cov_mat, np.cov(samps.T),
|
||||||
atol=0.05, rtol=0.05)
|
atol=0.05, rtol=0.05)
|
||||||
|
|
||||||
|
def testSampleWithBroadcastScale(self):
|
||||||
|
# mu corresponds to a 2-batch of 3-variate normals
|
||||||
|
mu = np.zeros([2, 3])
|
||||||
|
|
||||||
|
# diag corresponds to no batches of 3-variate normals
|
||||||
|
diag = np.ones([3])
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
|
||||||
|
|
||||||
|
mean = dist.mean()
|
||||||
|
self.assertAllEqual([2, 3], mean.get_shape())
|
||||||
|
self.assertAllClose(mu, mean.eval())
|
||||||
|
|
||||||
|
n = int(1e3)
|
||||||
|
samps = dist.sample(n, seed=0).eval()
|
||||||
|
cov_mat = array_ops.matrix_diag(diag).eval()**2
|
||||||
|
sample_cov = np.matmul(samps.transpose([1, 2, 0]),
|
||||||
|
samps.transpose([1, 0, 2])) / n
|
||||||
|
|
||||||
|
self.assertAllClose(mu, samps.mean(axis=0),
|
||||||
|
atol=0.10, rtol=0.05)
|
||||||
|
self.assertAllClose([cov_mat, cov_mat], sample_cov,
|
||||||
|
atol=0.10, rtol=0.05)
|
||||||
|
|
||||||
def testCovariance(self):
|
def testCovariance(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
mvn = ds.MultivariateNormalDiag(
|
mvn = ds.MultivariateNormalDiag(
|
||||||
|
@ -42,6 +42,28 @@ to integer values.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _bdtr(k, n, p):
|
||||||
|
"""The binomial cumulative distribution function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k: floating point `Tensor`.
|
||||||
|
n: floating point `Tensor`.
|
||||||
|
p: floating point `Tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`sum_{j=0}^k p^j (1 - p)^(n - j)`.
|
||||||
|
"""
|
||||||
|
# Trick for getting safe backprop/gradients into n, k when
|
||||||
|
# betainc(a = 0, ..) = nan
|
||||||
|
# Write:
|
||||||
|
# where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
|
||||||
|
ones = array_ops.ones_like(n - k)
|
||||||
|
k_eq_n = math_ops.equal(k, n)
|
||||||
|
safe_dn = array_ops.where(k_eq_n, ones, n - k)
|
||||||
|
dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p)
|
||||||
|
return array_ops.where(k_eq_n, ones, dk)
|
||||||
|
|
||||||
|
|
||||||
class Binomial(distribution.Distribution):
|
class Binomial(distribution.Distribution):
|
||||||
"""Binomial distribution.
|
"""Binomial distribution.
|
||||||
|
|
||||||
@ -201,6 +223,18 @@ class Binomial(distribution.Distribution):
|
|||||||
def _prob(self, counts):
|
def _prob(self, counts):
|
||||||
return math_ops.exp(self._log_prob(counts))
|
return math_ops.exp(self._log_prob(counts))
|
||||||
|
|
||||||
|
def _cdf(self, counts):
|
||||||
|
counts = self._maybe_assert_valid_sample(counts)
|
||||||
|
probs = self.probs
|
||||||
|
if not (counts.shape.is_fully_defined()
|
||||||
|
and self.probs.shape.is_fully_defined()
|
||||||
|
and counts.shape.is_compatible_with(self.probs.shape)):
|
||||||
|
# If both shapes are well defined and equal, we skip broadcasting.
|
||||||
|
probs += array_ops.zeros_like(counts)
|
||||||
|
counts += array_ops.zeros_like(self.probs)
|
||||||
|
|
||||||
|
return _bdtr(k=counts, n=self.total_count, p=probs)
|
||||||
|
|
||||||
def _log_unnormalized_prob(self, counts):
|
def _log_unnormalized_prob(self, counts):
|
||||||
counts = self._maybe_assert_valid_sample(counts)
|
counts = self._maybe_assert_valid_sample(counts)
|
||||||
return (counts * math_ops.log(self.probs) +
|
return (counts * math_ops.log(self.probs) +
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
|||||||
from tensorflow.contrib.distributions.python.ops import normal
|
from tensorflow.contrib.distributions.python.ops import normal
|
||||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import linalg_ops
|
from tensorflow.python.ops import linalg_ops
|
||||||
@ -53,6 +54,16 @@ or
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _broadcast_shape(shape1, shape2):
|
||||||
|
"""Convenience function which statically broadcasts shape when possible."""
|
||||||
|
if (tensor_util.constant_value(shape1) is not None and
|
||||||
|
tensor_util.constant_value(shape2) is not None):
|
||||||
|
return array_ops.broadcast_static_shape(
|
||||||
|
tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
|
||||||
|
tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
|
||||||
|
return array_ops.broadcast_dynamic_shape(shape1, shape2)
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
|
# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
|
||||||
class MultivariateNormalLinearOperator(
|
class MultivariateNormalLinearOperator(
|
||||||
transformed_distribution.TransformedDistribution):
|
transformed_distribution.TransformedDistribution):
|
||||||
@ -179,12 +190,25 @@ class MultivariateNormalLinearOperator(
|
|||||||
if not scale.dtype.is_floating:
|
if not scale.dtype.is_floating:
|
||||||
raise TypeError("`scale` parameter must have floating-point dtype.")
|
raise TypeError("`scale` parameter must have floating-point dtype.")
|
||||||
|
|
||||||
# Since expand_dims doesn't preserve constant-ness, we obtain the
|
with ops.name_scope(name, values=[loc] + scale.graph_parents):
|
||||||
# non-dynamic value if possible.
|
# Since expand_dims doesn't preserve constant-ness, we obtain the
|
||||||
event_shape = scale.domain_dimension_tensor()
|
# non-dynamic value if possible.
|
||||||
if tensor_util.constant_value(event_shape) is not None:
|
event_shape = scale.range_dimension_tensor()
|
||||||
event_shape = tensor_util.constant_value(event_shape)
|
if tensor_util.constant_value(event_shape) is not None:
|
||||||
event_shape = event_shape[array_ops.newaxis]
|
event_shape = tensor_util.constant_value(event_shape).reshape([1])
|
||||||
|
else:
|
||||||
|
event_shape = event_shape[array_ops.newaxis]
|
||||||
|
batch_shape = scale.batch_shape_tensor()
|
||||||
|
if loc is not None:
|
||||||
|
loc = ops.convert_to_tensor(loc, name="loc")
|
||||||
|
loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1]
|
||||||
|
if (loc.get_shape().ndims is None or
|
||||||
|
not loc_batch_shape.is_fully_defined()):
|
||||||
|
loc_batch_shape = array_ops.shape(loc)[:-1]
|
||||||
|
else:
|
||||||
|
loc_batch_shape = ops.convert_to_tensor(loc_batch_shape,
|
||||||
|
name="loc_batch_shape")
|
||||||
|
batch_shape = _broadcast_shape(batch_shape, loc_batch_shape)
|
||||||
|
|
||||||
super(MultivariateNormalLinearOperator, self).__init__(
|
super(MultivariateNormalLinearOperator, self).__init__(
|
||||||
distribution=normal.Normal(
|
distribution=normal.Normal(
|
||||||
@ -192,7 +216,7 @@ class MultivariateNormalLinearOperator(
|
|||||||
scale=array_ops.ones([], dtype=scale.dtype)),
|
scale=array_ops.ones([], dtype=scale.dtype)),
|
||||||
bijector=bijectors.AffineLinearOperator(
|
bijector=bijectors.AffineLinearOperator(
|
||||||
shift=loc, scale=scale, validate_args=validate_args),
|
shift=loc, scale=scale, validate_args=validate_args),
|
||||||
batch_shape=scale.batch_shape_tensor(),
|
batch_shape=batch_shape,
|
||||||
event_shape=event_shape,
|
event_shape=event_shape,
|
||||||
validate_args=validate_args,
|
validate_args=validate_args,
|
||||||
name=name)
|
name=name)
|
||||||
|
@ -35,6 +35,7 @@ tf_custom_op_py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":factorization_ops_test_utils_py",
|
||||||
":gen_clustering_ops",
|
":gen_clustering_ops",
|
||||||
":gen_factorization_ops",
|
":gen_factorization_ops",
|
||||||
"//tensorflow/contrib/framework:framework_py",
|
"//tensorflow/contrib/framework:framework_py",
|
||||||
@ -161,12 +162,28 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "factorization_ops_test_utils_py",
|
||||||
|
srcs = [
|
||||||
|
"python/ops/factorization_ops_test_utils.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:embedding_ops",
|
||||||
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:sparse_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "factorization_ops_test",
|
name = "factorization_ops_test",
|
||||||
srcs = ["python/ops/factorization_ops_test.py"],
|
srcs = ["python/ops/factorization_ops_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":factorization_py",
|
":factorization_py",
|
||||||
":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
|
":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
|
||||||
|
":factorization_ops_test_utils_py",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -18,160 +18,56 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.contrib.factorization.python.ops import factorization_ops
|
from tensorflow.contrib.factorization.python.ops import factorization_ops
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
INPUT_MATRIX = np.array(
|
|
||||||
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
|
|
||||||
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
|
|
||||||
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
|
|
||||||
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
|
|
||||||
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
|
|
||||||
|
|
||||||
|
INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX
|
||||||
def np_matrix_to_tf_sparse(np_matrix,
|
np_matrix_to_tf_sparse = factorization_ops_test_utils.np_matrix_to_tf_sparse
|
||||||
row_slices=None,
|
|
||||||
col_slices=None,
|
|
||||||
transpose=False,
|
|
||||||
shuffle=False):
|
|
||||||
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
|
|
||||||
indices = np.nonzero(np_matrix)
|
|
||||||
|
|
||||||
# Only allow slices of whole rows or whole columns.
|
|
||||||
assert not (row_slices is not None and col_slices is not None)
|
|
||||||
|
|
||||||
if row_slices is not None:
|
|
||||||
selected_ind = np.concatenate(
|
|
||||||
[np.where(indices[0] == r)[0] for r in row_slices], 0)
|
|
||||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
|
||||||
|
|
||||||
if col_slices is not None:
|
|
||||||
selected_ind = np.concatenate(
|
|
||||||
[np.where(indices[1] == c)[0] for c in col_slices], 0)
|
|
||||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
|
||||||
|
|
||||||
if shuffle:
|
|
||||||
shuffled_ind = [x for x in range(len(indices[0]))]
|
|
||||||
random.shuffle(shuffled_ind)
|
|
||||||
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
|
|
||||||
|
|
||||||
ind = (np.concatenate((np.expand_dims(indices[1], 1),
|
|
||||||
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
|
|
||||||
transpose else np.concatenate((np.expand_dims(indices[0], 1),
|
|
||||||
np.expand_dims(indices[1], 1)),
|
|
||||||
1).astype(np.int64))
|
|
||||||
val = np_matrix[indices].astype(np.float32)
|
|
||||||
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
|
|
||||||
if transpose else np.array(
|
|
||||||
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
|
|
||||||
return sparse_tensor.SparseTensor(ind, val, shape)
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_input():
|
|
||||||
return np_matrix_to_tf_sparse(INPUT_MATRIX)
|
|
||||||
|
|
||||||
|
|
||||||
def count_rows(sp_input):
|
|
||||||
return math_ops.cast(
|
|
||||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
|
|
||||||
dtypes.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def count_cols(sp_input):
|
|
||||||
return math_ops.cast(
|
|
||||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
|
|
||||||
dtypes.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
|
|
||||||
w0=1., row_weights=None, col_weights=None):
|
|
||||||
"""Calculates the loss of a given factorization.
|
|
||||||
|
|
||||||
Using a non distributed method, different than the one implemented in the
|
|
||||||
WALS model. The weight of an observed entry (i, j) (i.e. such that
|
|
||||||
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_mat: The input matrix, a SparseTensor of rank 2.
|
|
||||||
row_factors: The row factors, a dense Tensor of rank 2.
|
|
||||||
col_factors: The col factors, a dense Tensor of rank 2.
|
|
||||||
regularization: the regularization coefficient, a scalar.
|
|
||||||
w0: the weight of unobserved entries. A scalar.
|
|
||||||
row_weights: A dense tensor of rank 1.
|
|
||||||
col_weights: A dense tensor of rank 1.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The total loss.
|
|
||||||
"""
|
|
||||||
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
|
|
||||||
else constant_op.constant(1.))
|
|
||||||
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
|
|
||||||
else constant_op.constant(1.))
|
|
||||||
reg = (regularization if regularization is not None
|
|
||||||
else constant_op.constant(0.))
|
|
||||||
|
|
||||||
row_indices, col_indices = array_ops.split(input_mat.indices,
|
|
||||||
axis=1,
|
|
||||||
num_or_size_splits=2)
|
|
||||||
gathered_row_factors = array_ops.gather(row_factors, row_indices)
|
|
||||||
gathered_col_factors = array_ops.gather(col_factors, col_indices)
|
|
||||||
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
|
|
||||||
gathered_row_factors, gathered_col_factors, adjoint_b=True))
|
|
||||||
sp_approx = sparse_tensor.SparseTensor(
|
|
||||||
indices=input_mat.indices,
|
|
||||||
values=sp_approx_vals,
|
|
||||||
dense_shape=input_mat.dense_shape)
|
|
||||||
|
|
||||||
sp_approx_sq = math_ops.square(sp_approx)
|
|
||||||
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
|
|
||||||
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
|
|
||||||
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
|
|
||||||
row_factors, col_factors, transpose_b=True)))
|
|
||||||
|
|
||||||
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
|
|
||||||
resid_sq = math_ops.square(resid)
|
|
||||||
loss = w0 * (
|
|
||||||
sparse_ops.sparse_reduce_sum(resid_sq) -
|
|
||||||
sparse_ops.sparse_reduce_sum(sp_approx_sq)
|
|
||||||
)
|
|
||||||
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
|
|
||||||
w0 * row_col_norm + reg * (row_norm + col_norm))
|
|
||||||
return loss.eval()
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_loss_from_wals_model(wals_model, sp_inputs):
|
|
||||||
current_rows = embedding_ops.embedding_lookup(
|
|
||||||
wals_model.row_factors, math_ops.range(wals_model._input_rows),
|
|
||||||
partition_strategy="div")
|
|
||||||
current_cols = embedding_ops.embedding_lookup(
|
|
||||||
wals_model.col_factors, math_ops.range(wals_model._input_cols),
|
|
||||||
partition_strategy="div")
|
|
||||||
row_wts = embedding_ops.embedding_lookup(
|
|
||||||
wals_model._row_weights, math_ops.range(wals_model._input_rows),
|
|
||||||
partition_strategy="div")
|
|
||||||
col_wts = embedding_ops.embedding_lookup(
|
|
||||||
wals_model._col_weights, math_ops.range(wals_model._input_cols),
|
|
||||||
partition_strategy="div")
|
|
||||||
return calculate_loss(
|
|
||||||
sp_inputs, current_rows, current_cols, wals_model._regularization,
|
|
||||||
wals_model._unobserved_weight, row_wts, col_wts)
|
|
||||||
|
|
||||||
|
|
||||||
class WalsModelTest(test.TestCase):
|
class WalsModelTest(test.TestCase):
|
||||||
|
|
||||||
|
def sparse_input(self):
|
||||||
|
return np_matrix_to_tf_sparse(INPUT_MATRIX)
|
||||||
|
|
||||||
|
def count_rows(self, sp_input):
|
||||||
|
return math_ops.cast(
|
||||||
|
array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
|
||||||
|
dtypes.float32)
|
||||||
|
|
||||||
|
def count_cols(self, sp_input):
|
||||||
|
return math_ops.cast(
|
||||||
|
array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
|
||||||
|
dtypes.float32)
|
||||||
|
|
||||||
|
def calculate_loss_from_wals_model(self, wals_model, sp_inputs):
|
||||||
|
current_rows = embedding_ops.embedding_lookup(
|
||||||
|
wals_model.row_factors, math_ops.range(wals_model._input_rows),
|
||||||
|
partition_strategy="div")
|
||||||
|
current_cols = embedding_ops.embedding_lookup(
|
||||||
|
wals_model.col_factors, math_ops.range(wals_model._input_cols),
|
||||||
|
partition_strategy="div")
|
||||||
|
row_wts = embedding_ops.embedding_lookup(
|
||||||
|
wals_model._row_weights, math_ops.range(wals_model._input_rows),
|
||||||
|
partition_strategy="div")
|
||||||
|
col_wts = embedding_ops.embedding_lookup(
|
||||||
|
wals_model._col_weights, math_ops.range(wals_model._input_cols),
|
||||||
|
partition_strategy="div")
|
||||||
|
return factorization_ops_test_utils.calculate_loss(
|
||||||
|
sp_inputs, current_rows, current_cols, wals_model._regularization,
|
||||||
|
wals_model._unobserved_weight, row_wts, col_wts)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.col_init = [
|
self.col_init = [
|
||||||
# shard 0
|
# shard 0
|
||||||
@ -208,7 +104,7 @@ class WalsModelTest(test.TestCase):
|
|||||||
use_factors_weights_cache,
|
use_factors_weights_cache,
|
||||||
compute_loss=False):
|
compute_loss=False):
|
||||||
with ops.Graph().as_default(), self.test_session() as sess:
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
self._wals_inputs = sparse_input()
|
self._wals_inputs = self.sparse_input()
|
||||||
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
||||||
num_rows = 5
|
num_rows = 5
|
||||||
num_cols = 7
|
num_cols = 7
|
||||||
@ -282,10 +178,10 @@ class WalsModelTest(test.TestCase):
|
|||||||
if compute_loss:
|
if compute_loss:
|
||||||
# Test loss computation after the row update
|
# Test loss computation after the row update
|
||||||
loss = sum(
|
loss = sum(
|
||||||
sess.run(factor_loss * count_rows(inp) / num_rows,
|
sess.run(factor_loss * self.count_rows(inp) / num_rows,
|
||||||
feed_dict={sp_feeder: inp})
|
feed_dict={sp_feeder: inp})
|
||||||
for inp in input_scattered_rows)
|
for inp in input_scattered_rows)
|
||||||
true_loss = calculate_loss_from_wals_model(
|
true_loss = self.calculate_loss_from_wals_model(
|
||||||
wals_model, self._wals_inputs)
|
wals_model, self._wals_inputs)
|
||||||
self.assertNear(
|
self.assertNear(
|
||||||
loss, true_loss, err=.001,
|
loss, true_loss, err=.001,
|
||||||
@ -355,10 +251,10 @@ class WalsModelTest(test.TestCase):
|
|||||||
if compute_loss:
|
if compute_loss:
|
||||||
# Test loss computation after the column update.
|
# Test loss computation after the column update.
|
||||||
loss = sum(
|
loss = sum(
|
||||||
sess.run(factor_loss * count_cols(inp) / num_cols,
|
sess.run(factor_loss * self.count_cols(inp) / num_cols,
|
||||||
feed_dict={sp_feeder: inp})
|
feed_dict={sp_feeder: inp})
|
||||||
for inp in input_scattered_cols_non_duplicate)
|
for inp in input_scattered_cols_non_duplicate)
|
||||||
true_loss = calculate_loss_from_wals_model(
|
true_loss = self.calculate_loss_from_wals_model(
|
||||||
wals_model, self._wals_inputs)
|
wals_model, self._wals_inputs)
|
||||||
self.assertNear(
|
self.assertNear(
|
||||||
loss, true_loss, err=.001,
|
loss, true_loss, err=.001,
|
||||||
@ -368,7 +264,7 @@ class WalsModelTest(test.TestCase):
|
|||||||
def _run_test_process_input_transposed(self, use_factors_weights_cache,
|
def _run_test_process_input_transposed(self, use_factors_weights_cache,
|
||||||
compute_loss=False):
|
compute_loss=False):
|
||||||
with ops.Graph().as_default(), self.test_session() as sess:
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
self._wals_inputs = sparse_input()
|
self._wals_inputs = self.sparse_input()
|
||||||
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
||||||
num_rows = 5
|
num_rows = 5
|
||||||
num_cols = 7
|
num_cols = 7
|
||||||
@ -448,10 +344,10 @@ class WalsModelTest(test.TestCase):
|
|||||||
if compute_loss:
|
if compute_loss:
|
||||||
# Test loss computation after the row update
|
# Test loss computation after the row update
|
||||||
loss = sum(
|
loss = sum(
|
||||||
sess.run(factor_loss * count_cols(inp) / num_rows,
|
sess.run(factor_loss * self.count_cols(inp) / num_rows,
|
||||||
feed_dict={sp_feeder: inp})
|
feed_dict={sp_feeder: inp})
|
||||||
for inp in input_scattered_rows_non_duplicate)
|
for inp in input_scattered_rows_non_duplicate)
|
||||||
true_loss = calculate_loss_from_wals_model(
|
true_loss = self.calculate_loss_from_wals_model(
|
||||||
wals_model, self._wals_inputs)
|
wals_model, self._wals_inputs)
|
||||||
self.assertNear(
|
self.assertNear(
|
||||||
loss, true_loss, err=.001,
|
loss, true_loss, err=.001,
|
||||||
@ -516,10 +412,10 @@ class WalsModelTest(test.TestCase):
|
|||||||
if compute_loss:
|
if compute_loss:
|
||||||
# Test loss computation after the col update
|
# Test loss computation after the col update
|
||||||
loss = sum(
|
loss = sum(
|
||||||
sess.run(factor_loss * count_rows(inp) / num_cols,
|
sess.run(factor_loss * self.count_rows(inp) / num_cols,
|
||||||
feed_dict={sp_feeder: inp})
|
feed_dict={sp_feeder: inp})
|
||||||
for inp in input_scattered_cols_non_duplicate)
|
for inp in input_scattered_cols_non_duplicate)
|
||||||
true_loss = calculate_loss_from_wals_model(
|
true_loss = self.calculate_loss_from_wals_model(
|
||||||
wals_model, self._wals_inputs)
|
wals_model, self._wals_inputs)
|
||||||
self.assertNear(
|
self.assertNear(
|
||||||
loss, true_loss, err=.001,
|
loss, true_loss, err=.001,
|
||||||
@ -534,7 +430,7 @@ class WalsModelTest(test.TestCase):
|
|||||||
# Here we test that those two give identical results.
|
# Here we test that those two give identical results.
|
||||||
def _run_test_als(self, use_factors_weights_cache):
|
def _run_test_als(self, use_factors_weights_cache):
|
||||||
with ops.Graph().as_default(), self.test_session():
|
with ops.Graph().as_default(), self.test_session():
|
||||||
self._wals_inputs = sparse_input()
|
self._wals_inputs = self.sparse_input()
|
||||||
col_init = np.random.rand(7, 3)
|
col_init = np.random.rand(7, 3)
|
||||||
als_model = factorization_ops.WALSModel(
|
als_model = factorization_ops.WALSModel(
|
||||||
5,
|
5,
|
||||||
@ -613,7 +509,7 @@ class WalsModelTest(test.TestCase):
|
|||||||
|
|
||||||
def _run_test_als_transposed(self, use_factors_weights_cache):
|
def _run_test_als_transposed(self, use_factors_weights_cache):
|
||||||
with ops.Graph().as_default(), self.test_session():
|
with ops.Graph().as_default(), self.test_session():
|
||||||
self._wals_inputs = sparse_input()
|
self._wals_inputs = self.sparse_input()
|
||||||
col_init = np.random.rand(7, 3)
|
col_init = np.random.rand(7, 3)
|
||||||
als_model = factorization_ops.WALSModel(
|
als_model = factorization_ops.WALSModel(
|
||||||
5,
|
5,
|
||||||
|
@ -0,0 +1,131 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test utils for factorization_ops."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
|
||||||
|
|
||||||
|
INPUT_MATRIX = np.array(
|
||||||
|
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
|
||||||
|
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
|
||||||
|
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
|
||||||
|
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
|
||||||
|
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def np_matrix_to_tf_sparse(np_matrix,
|
||||||
|
row_slices=None,
|
||||||
|
col_slices=None,
|
||||||
|
transpose=False,
|
||||||
|
shuffle=False):
|
||||||
|
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
|
||||||
|
indices = np.nonzero(np_matrix)
|
||||||
|
|
||||||
|
# Only allow slices of whole rows or whole columns.
|
||||||
|
assert not (row_slices is not None and col_slices is not None)
|
||||||
|
|
||||||
|
if row_slices is not None:
|
||||||
|
selected_ind = np.concatenate(
|
||||||
|
[np.where(indices[0] == r)[0] for r in row_slices], 0)
|
||||||
|
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||||
|
|
||||||
|
if col_slices is not None:
|
||||||
|
selected_ind = np.concatenate(
|
||||||
|
[np.where(indices[1] == c)[0] for c in col_slices], 0)
|
||||||
|
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
shuffled_ind = [x for x in range(len(indices[0]))]
|
||||||
|
random.shuffle(shuffled_ind)
|
||||||
|
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
|
||||||
|
|
||||||
|
ind = (np.concatenate((np.expand_dims(indices[1], 1),
|
||||||
|
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
|
||||||
|
transpose else np.concatenate((np.expand_dims(indices[0], 1),
|
||||||
|
np.expand_dims(indices[1], 1)),
|
||||||
|
1).astype(np.int64))
|
||||||
|
val = np_matrix[indices].astype(np.float32)
|
||||||
|
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
|
||||||
|
if transpose else np.array(
|
||||||
|
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
|
||||||
|
return sparse_tensor.SparseTensor(ind, val, shape)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
|
||||||
|
w0=1., row_weights=None, col_weights=None):
|
||||||
|
"""Calculates the loss of a given factorization.
|
||||||
|
|
||||||
|
Using a non distributed method, different than the one implemented in the
|
||||||
|
WALS model. The weight of an observed entry (i, j) (i.e. such that
|
||||||
|
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_mat: The input matrix, a SparseTensor of rank 2.
|
||||||
|
row_factors: The row factors, a dense Tensor of rank 2.
|
||||||
|
col_factors: The col factors, a dense Tensor of rank 2.
|
||||||
|
regularization: the regularization coefficient, a scalar.
|
||||||
|
w0: the weight of unobserved entries. A scalar.
|
||||||
|
row_weights: A dense tensor of rank 1.
|
||||||
|
col_weights: A dense tensor of rank 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The total loss.
|
||||||
|
"""
|
||||||
|
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
|
||||||
|
else constant_op.constant(1.))
|
||||||
|
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
|
||||||
|
else constant_op.constant(1.))
|
||||||
|
reg = (regularization if regularization is not None
|
||||||
|
else constant_op.constant(0.))
|
||||||
|
|
||||||
|
row_indices, col_indices = array_ops.split(input_mat.indices,
|
||||||
|
axis=1,
|
||||||
|
num_or_size_splits=2)
|
||||||
|
gathered_row_factors = array_ops.gather(row_factors, row_indices)
|
||||||
|
gathered_col_factors = array_ops.gather(col_factors, col_indices)
|
||||||
|
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
|
||||||
|
gathered_row_factors, gathered_col_factors, adjoint_b=True))
|
||||||
|
sp_approx = sparse_tensor.SparseTensor(
|
||||||
|
indices=input_mat.indices,
|
||||||
|
values=sp_approx_vals,
|
||||||
|
dense_shape=input_mat.dense_shape)
|
||||||
|
|
||||||
|
sp_approx_sq = math_ops.square(sp_approx)
|
||||||
|
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
|
||||||
|
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
|
||||||
|
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
|
||||||
|
row_factors, col_factors, transpose_b=True)))
|
||||||
|
|
||||||
|
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
|
||||||
|
resid_sq = math_ops.square(resid)
|
||||||
|
loss = w0 * (
|
||||||
|
sparse_ops.sparse_reduce_sum(resid_sq) -
|
||||||
|
sparse_ops.sparse_reduce_sum(sp_approx_sq)
|
||||||
|
)
|
||||||
|
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
|
||||||
|
w0 * row_col_norm + reg * (row_norm + col_norm))
|
||||||
|
return loss.eval()
|
@ -21,9 +21,14 @@ py_library(
|
|||||||
":dense_kernel_mapper_py",
|
":dense_kernel_mapper_py",
|
||||||
"//tensorflow/contrib/layers:layers_py",
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
"//tensorflow/contrib/learn",
|
"//tensorflow/contrib/learn",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:ops",
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:util",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,6 +36,7 @@ py_library(
|
|||||||
name = "dense_kernel_mapper_py",
|
name = "dense_kernel_mapper_py",
|
||||||
srcs = ["python/mappers/dense_kernel_mapper.py"],
|
srcs = ["python/mappers/dense_kernel_mapper.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
deps = ["@six_archive//:six"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
@ -40,12 +46,12 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":dense_kernel_mapper_py",
|
":dense_kernel_mapper_py",
|
||||||
":kernel_methods",
|
":kernel_methods",
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:ops",
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -55,10 +61,12 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":kernel_methods",
|
":kernel_methods",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
|
"//tensorflow/contrib/learn",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:ops",
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:sparse_tensor",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
|
from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
|
||||||
from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper
|
|
||||||
from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
|
from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
|
||||||
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
@ -118,6 +118,7 @@ tf_custom_op_py_library(
|
|||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:check_ops",
|
"//tensorflow/python:check_ops",
|
||||||
"//tensorflow/python:clip_ops",
|
"//tensorflow/python:clip_ops",
|
||||||
|
"//tensorflow/python:common_shapes",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:embedding_ops",
|
"//tensorflow/python:embedding_ops",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
@ -131,9 +132,11 @@ tf_custom_op_py_library(
|
|||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:sparse_ops",
|
"//tensorflow/python:sparse_ops",
|
||||||
|
"//tensorflow/python:sparse_tensor",
|
||||||
"//tensorflow/python:standard_ops",
|
"//tensorflow/python:standard_ops",
|
||||||
"//tensorflow/python:string_ops",
|
"//tensorflow/python:string_ops",
|
||||||
"//tensorflow/python:summary",
|
"//tensorflow/python:summary",
|
||||||
|
"//tensorflow/python:tensor_util",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||||
from tensorflow.contrib.framework.python.framework import experimental
|
from tensorflow.contrib.framework.python.framework import experimental
|
||||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||||
@ -36,6 +38,7 @@ from tensorflow.python.ops import sparse_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
def _embeddings_from_arguments(column,
|
def _embeddings_from_arguments(column,
|
||||||
@ -136,6 +139,58 @@ def _embeddings_from_arguments(column,
|
|||||||
max_norm=args.max_norm)
|
max_norm=args.max_norm)
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
|
||||||
|
"""Reshape the input tensor by the following rule.
|
||||||
|
|
||||||
|
1. If `output_rank > input_rank + 1`, raise a `ValueError`.
|
||||||
|
2. If `output_rank == input_rank + 1`, expand the tensor by one dimension.
|
||||||
|
3. If `output_rank == input_rank`, do nothing.
|
||||||
|
4. If `output_rank < input_rank`, flatten the inner dimensions of the tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A Tensor or SparseTensor to be reshaped.
|
||||||
|
column_name: A string name of the feature column for the tensor.
|
||||||
|
output_rank: the desired rank of the tensor.
|
||||||
|
Returns:
|
||||||
|
A reshaped Tensor or SparseTensor.
|
||||||
|
Raises:
|
||||||
|
ValueError: if `output_rank > input_rank + 1` for the input tensor.
|
||||||
|
"""
|
||||||
|
input_rank = tensor.get_shape().ndims
|
||||||
|
|
||||||
|
if input_rank is None and isinstance(tensor, sparse_tensor_py.SparseTensor):
|
||||||
|
# Try to get the rank of a sparse tensor by its dense_shape's shape.
|
||||||
|
input_rank = tensor.dense_shape.get_shape().as_list()[0]
|
||||||
|
|
||||||
|
if input_rank is None:
|
||||||
|
raise ValueError('Error while processing column {}. Rank of input Tensor '
|
||||||
|
'can not be None.'.format(column_name))
|
||||||
|
|
||||||
|
if output_rank > input_rank + 1:
|
||||||
|
raise ValueError('Error while processing column {}. Rank of input Tensor '
|
||||||
|
'({}) should be the same as output_rank ({}). For '
|
||||||
|
'example, sequence data should typically be 3 '
|
||||||
|
'dimensional (rank 3) while non-sequence data is '
|
||||||
|
'typically 2 dimensional (rank 2).'.format(
|
||||||
|
column_name, input_rank, output_rank))
|
||||||
|
elif output_rank == input_rank + 1:
|
||||||
|
# Expand the tensor's shape by 1 dimension.
|
||||||
|
if isinstance(tensor, sparse_tensor_py.SparseTensor):
|
||||||
|
output_shape = array_ops.concat([tensor.dense_shape, [1]], 0)
|
||||||
|
return sparse_ops.sparse_reshape(tensor, output_shape)
|
||||||
|
else:
|
||||||
|
reshaped = array_ops.expand_dims(tensor, -1)
|
||||||
|
# Try to calculate the new shape.
|
||||||
|
static_shape = tensor.get_shape()
|
||||||
|
if static_shape is not None and static_shape.dims is not None:
|
||||||
|
reshaped.set_shape(static_shape.as_list() + [1])
|
||||||
|
return reshaped
|
||||||
|
elif output_rank < input_rank:
|
||||||
|
return layers._inner_flatten(tensor, output_rank) # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def _input_from_feature_columns(columns_to_tensors,
|
def _input_from_feature_columns(columns_to_tensors,
|
||||||
feature_columns,
|
feature_columns,
|
||||||
weight_collections,
|
weight_collections,
|
||||||
@ -160,6 +215,12 @@ def _input_from_feature_columns(columns_to_tensors,
|
|||||||
default_name=column.name,
|
default_name=column.name,
|
||||||
values=columns_to_tensors.values()):
|
values=columns_to_tensors.values()):
|
||||||
transformed_tensor = transformer.transform(column)
|
transformed_tensor = transformer.transform(column)
|
||||||
|
if output_rank == 3:
|
||||||
|
transformed_tensor = nest.map_structure(
|
||||||
|
functools.partial(
|
||||||
|
_maybe_reshape_input_tensor,
|
||||||
|
column_name=column.name,
|
||||||
|
output_rank=output_rank), transformed_tensor)
|
||||||
try:
|
try:
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
arguments = column._deep_embedding_lookup_arguments(
|
arguments = column._deep_embedding_lookup_arguments(
|
||||||
@ -548,7 +609,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
|||||||
default_name=column.name,
|
default_name=column.name,
|
||||||
values=columns_to_tensors.values()):
|
values=columns_to_tensors.values()):
|
||||||
tensor = column._to_dense_tensor(transformed_tensor)
|
tensor = column._to_dense_tensor(transformed_tensor)
|
||||||
tensor = fc._reshape_real_valued_tensor(tensor, 2, column.name)
|
tensor = _maybe_reshape_input_tensor(
|
||||||
|
tensor, column.name, output_rank=2)
|
||||||
variable = [
|
variable = [
|
||||||
contrib_variables.model_variable(
|
contrib_variables.model_variable(
|
||||||
name='weight',
|
name='weight',
|
||||||
|
@ -1350,6 +1350,35 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||||
|
|
||||||
|
def testEmbeddingColumnWithAutoReshape(self):
|
||||||
|
hash_buckets = 10
|
||||||
|
embedding_dimension = 5
|
||||||
|
ids_tensor = sparse_tensor.SparseTensor(
|
||||||
|
values=["c", "b",
|
||||||
|
"a", "c", "b",
|
||||||
|
"b"],
|
||||||
|
indices=[[0, 0], [0, 1],
|
||||||
|
[1, 0], [1, 1], [1, 2],
|
||||||
|
[3, 2]],
|
||||||
|
dense_shape=[4, 3])
|
||||||
|
|
||||||
|
expected_input_shape = np.array([4, 3, embedding_dimension])
|
||||||
|
|
||||||
|
hashed_ids_column = feature_column.sparse_column_with_hash_bucket(
|
||||||
|
"ids", hash_buckets)
|
||||||
|
embedded_column = feature_column.embedding_column(hashed_ids_column,
|
||||||
|
embedding_dimension)
|
||||||
|
columns_to_tensors = {"ids": ids_tensor}
|
||||||
|
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
|
||||||
|
columns_to_tensors, [embedded_column])
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
variables_lib.global_variables_initializer().run()
|
||||||
|
data_flow_ops.tables_initializer().run()
|
||||||
|
model_input = sess.run(model_input_tensor)
|
||||||
|
|
||||||
|
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||||
|
|
||||||
def testEmbeddingColumnGradient(self):
|
def testEmbeddingColumnGradient(self):
|
||||||
hash_buckets = 1000
|
hash_buckets = 1000
|
||||||
embedding_dimension = 3
|
embedding_dimension = 3
|
||||||
|
@ -836,6 +836,19 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_fn_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/learn/estimators/model_fn_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":learn",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "multioutput_test",
|
name = "multioutput_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.summary import summary
|
from tensorflow.python.summary import summary
|
||||||
@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead):
|
|||||||
loss_fn=self._loss_fn,
|
loss_fn=self._loss_fn,
|
||||||
logits_to_predictions_fn=self._logits_to_predictions,
|
logits_to_predictions_fn=self._logits_to_predictions,
|
||||||
metrics_fn=self._metrics,
|
metrics_fn=self._metrics,
|
||||||
create_output_alternatives_fn=self._create_output_alternatives,
|
create_output_alternatives_fn=_classification_output_alternatives(
|
||||||
|
self.head_name, self._problem_type),
|
||||||
labels=labels,
|
labels=labels,
|
||||||
train_op_fn=train_op_fn,
|
train_op_fn=train_op_fn,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -885,6 +887,8 @@ class _BinaryLogisticHead(_SingleHead):
|
|||||||
_indicator_labels_streaming_mean(labels, weights))
|
_indicator_labels_streaming_mean(labels, weights))
|
||||||
metrics[_summary_key(self.head_name, mkey.AUC)] = (
|
metrics[_summary_key(self.head_name, mkey.AUC)] = (
|
||||||
_streaming_auc(logistic, labels, weights))
|
_streaming_auc(logistic, labels, weights))
|
||||||
|
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = (
|
||||||
|
_streaming_auc(logistic, labels, weights, curve="PR"))
|
||||||
|
|
||||||
for threshold in self._thresholds:
|
for threshold in self._thresholds:
|
||||||
metrics[_summary_key(
|
metrics[_summary_key(
|
||||||
@ -1009,7 +1013,8 @@ class _MultiClassHead(_SingleHead):
|
|||||||
loss_fn=self._wrapped_loss_fn,
|
loss_fn=self._wrapped_loss_fn,
|
||||||
logits_to_predictions_fn=self._logits_to_predictions,
|
logits_to_predictions_fn=self._logits_to_predictions,
|
||||||
metrics_fn=self._metrics,
|
metrics_fn=self._metrics,
|
||||||
create_output_alternatives_fn=self._create_output_alternatives,
|
create_output_alternatives_fn=_classification_output_alternatives(
|
||||||
|
self.head_name, self._problem_type, self._label_keys),
|
||||||
labels=labels,
|
labels=labels,
|
||||||
train_op_fn=train_op_fn,
|
train_op_fn=train_op_fn,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -1113,25 +1118,6 @@ class _MultiClassHead(_SingleHead):
|
|||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def _create_output_alternatives(self, predictions):
|
|
||||||
"""See superclass."""
|
|
||||||
probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
|
|
||||||
batch_size = array_ops.shape(probabilities)[0]
|
|
||||||
if self._label_keys:
|
|
||||||
classes = array_ops.tile(
|
|
||||||
input=array_ops.expand_dims(input=self._label_keys, axis=0),
|
|
||||||
multiples=[batch_size, 1])
|
|
||||||
else:
|
|
||||||
classes = array_ops.tile(
|
|
||||||
input=array_ops.expand_dims(
|
|
||||||
input=math_ops.range(self.logits_dimension), axis=0),
|
|
||||||
multiples=[batch_size, 1])
|
|
||||||
predictions_for_serving = {
|
|
||||||
prediction_key.PredictionKey.CLASSES: classes,
|
|
||||||
prediction_key.PredictionKey.PROBABILITIES: probabilities,
|
|
||||||
}
|
|
||||||
return {self._head_name: (self._problem_type, predictions_for_serving)}
|
|
||||||
|
|
||||||
|
|
||||||
def _to_labels_tensor(labels, label_name):
|
def _to_labels_tensor(labels, label_name):
|
||||||
"""Returns label as a tensor.
|
"""Returns label as a tensor.
|
||||||
@ -1226,6 +1212,7 @@ class _BinarySvmHead(_SingleHead):
|
|||||||
loss_fn=self._loss_fn,
|
loss_fn=self._loss_fn,
|
||||||
logits_to_predictions_fn=self._logits_to_predictions,
|
logits_to_predictions_fn=self._logits_to_predictions,
|
||||||
metrics_fn=self._metrics,
|
metrics_fn=self._metrics,
|
||||||
|
# TODO(zakaria): Handle labels for export.
|
||||||
create_output_alternatives_fn=self._create_output_alternatives,
|
create_output_alternatives_fn=self._create_output_alternatives,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
train_op_fn=train_op_fn,
|
train_op_fn=train_op_fn,
|
||||||
@ -1325,7 +1312,8 @@ class _MultiLabelHead(_SingleHead):
|
|||||||
loss_fn=self._loss_fn,
|
loss_fn=self._loss_fn,
|
||||||
logits_to_predictions_fn=self._logits_to_predictions,
|
logits_to_predictions_fn=self._logits_to_predictions,
|
||||||
metrics_fn=self._metrics,
|
metrics_fn=self._metrics,
|
||||||
create_output_alternatives_fn=self._create_output_alternatives,
|
create_output_alternatives_fn=_classification_output_alternatives(
|
||||||
|
self.head_name, self._problem_type),
|
||||||
labels=labels,
|
labels=labels,
|
||||||
train_op_fn=train_op_fn,
|
train_op_fn=train_op_fn,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -1374,6 +1362,8 @@ class _MultiLabelHead(_SingleHead):
|
|||||||
metrics_lib.streaming_accuracy(classes, labels, weights))
|
metrics_lib.streaming_accuracy(classes, labels, weights))
|
||||||
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
|
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
|
||||||
probabilities, labels, weights)
|
probabilities, labels, weights)
|
||||||
|
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
|
||||||
|
probabilities, labels, weights, curve="PR")
|
||||||
|
|
||||||
for class_id in self._metric_class_ids:
|
for class_id in self._metric_class_ids:
|
||||||
# TODO(ptucker): Add per-class accuracy, precision, recall.
|
# TODO(ptucker): Add per-class accuracy, precision, recall.
|
||||||
@ -1391,6 +1381,9 @@ class _MultiLabelHead(_SingleHead):
|
|||||||
_predictions_streaming_mean(logits, weights, class_id))
|
_predictions_streaming_mean(logits, weights, class_id))
|
||||||
metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
|
metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
|
||||||
_streaming_auc(probabilities, labels, weights, class_id))
|
_streaming_auc(probabilities, labels, weights, class_id))
|
||||||
|
metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = (
|
||||||
|
_streaming_auc(probabilities, labels, weights, class_id,
|
||||||
|
curve="PR"))
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
@ -1857,7 +1850,8 @@ def _class_labels_streaming_mean(labels, weights, class_id):
|
|||||||
weights=weights)
|
weights=weights)
|
||||||
|
|
||||||
|
|
||||||
def _streaming_auc(predictions, labels, weights=None, class_id=None):
|
def _streaming_auc(predictions, labels, weights=None, class_id=None,
|
||||||
|
curve="ROC"):
|
||||||
predictions = ops.convert_to_tensor(predictions)
|
predictions = ops.convert_to_tensor(predictions)
|
||||||
labels = ops.convert_to_tensor(labels)
|
labels = ops.convert_to_tensor(labels)
|
||||||
if class_id is not None:
|
if class_id is not None:
|
||||||
@ -1866,7 +1860,8 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None):
|
|||||||
return metrics_lib.streaming_auc(
|
return metrics_lib.streaming_auc(
|
||||||
predictions,
|
predictions,
|
||||||
math_ops.cast(labels, dtypes.bool),
|
math_ops.cast(labels, dtypes.bool),
|
||||||
weights=_float_weights_or_none(weights))
|
weights=_float_weights_or_none(weights),
|
||||||
|
curve=curve)
|
||||||
|
|
||||||
|
|
||||||
def _assert_class_id(class_id, num_classes=None):
|
def _assert_class_id(class_id, num_classes=None):
|
||||||
@ -1901,6 +1896,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
|
|||||||
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
||||||
|
|
||||||
|
|
||||||
|
def _classification_output_alternatives(head_name, problem_type,
|
||||||
|
label_keys=None):
|
||||||
|
"""Creates a func to generate output alternatives for classification.
|
||||||
|
|
||||||
|
Servo expects classes to be a string tensor, and have the same dimensions
|
||||||
|
as the probabilities tensor. It should contain the labels of the corresponding
|
||||||
|
entries in probabilities. This function creates a new classes tensor that
|
||||||
|
satisfies these conditions and can be exported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
head_name: Name of the head.
|
||||||
|
problem_type: `ProblemType`
|
||||||
|
label_keys: Optional label keys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function to generate output alternatives.
|
||||||
|
"""
|
||||||
|
def _create_output_alternatives(predictions):
|
||||||
|
"""Creates output alternative for the Head.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
|
||||||
|
symbolic name for an output Tensor possibly but not necessarily taken
|
||||||
|
from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
|
||||||
|
itself.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
|
||||||
|
'submodel_name' is a submodel identifier that should be consistent across
|
||||||
|
the pipeline (here likely taken from the head_name),
|
||||||
|
'problem_type' is a `ProblemType`,
|
||||||
|
'tensor_name' is a symbolic name for an output Tensor possibly but not
|
||||||
|
necessarily taken from `PredictionKey`, and
|
||||||
|
'Tensor' is the corresponding output Tensor itself.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if predictions does not have PredictionKey.PROBABILITIES key.
|
||||||
|
"""
|
||||||
|
probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
|
||||||
|
if probabilities is None:
|
||||||
|
raise ValueError("%s missing in predictions" %
|
||||||
|
prediction_key.PredictionKey.PROBABILITIES)
|
||||||
|
|
||||||
|
with ops.name_scope(None, "_classification_output_alternatives",
|
||||||
|
(probabilities,)):
|
||||||
|
batch_size = array_ops.shape(probabilities)[0]
|
||||||
|
if label_keys:
|
||||||
|
classes = array_ops.tile(
|
||||||
|
input=array_ops.expand_dims(input=label_keys, axis=0),
|
||||||
|
multiples=[batch_size, 1],
|
||||||
|
name="classes_tensor")
|
||||||
|
else:
|
||||||
|
n = array_ops.shape(probabilities)[1]
|
||||||
|
classes = array_ops.tile(
|
||||||
|
input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
|
||||||
|
multiples=[batch_size, 1])
|
||||||
|
classes = string_ops.as_string(classes, name="classes_tensor")
|
||||||
|
|
||||||
|
exported_predictions = {
|
||||||
|
prediction_key.PredictionKey.PROBABILITIES: probabilities,
|
||||||
|
prediction_key.PredictionKey.CLASSES: classes}
|
||||||
|
return {head_name: (problem_type, exported_predictions)}
|
||||||
|
|
||||||
|
return _create_output_alternatives
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
# TODO(zakaria): Remove these aliases, See b/34751732
|
# TODO(zakaria): Remove these aliases, See b/34751732
|
||||||
_regression_head = regression_head
|
_regression_head = regression_head
|
||||||
|
@ -297,11 +297,15 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
def _expected_eval_metrics(self, expected_loss):
|
def _expected_eval_metrics(self, expected_loss):
|
||||||
return {
|
return {
|
||||||
"accuracy": 1. / 3,
|
"accuracy": 1. / 3,
|
||||||
"auc": 1. / 4,
|
|
||||||
"loss": expected_loss,
|
"loss": expected_loss,
|
||||||
|
"auc": 1. / 4,
|
||||||
"auc/class0": 1.,
|
"auc/class0": 1.,
|
||||||
"auc/class1": 1.,
|
"auc/class1": 1.,
|
||||||
"auc/class2": 0.,
|
"auc/class2": 0.,
|
||||||
|
"auc_precision_recall": 0.166667,
|
||||||
|
"auc_precision_recall/class0": 0,
|
||||||
|
"auc_precision_recall/class1": 0.,
|
||||||
|
"auc_precision_recall/class2": 1.,
|
||||||
"labels/actual_label_mean/class0": self._labels[0][0],
|
"labels/actual_label_mean/class0": self._labels[0][0],
|
||||||
"labels/actual_label_mean/class1": self._labels[0][1],
|
"labels/actual_label_mean/class1": self._labels[0][1],
|
||||||
"labels/actual_label_mean/class2": self._labels[0][2],
|
"labels/actual_label_mean/class2": self._labels[0][2],
|
||||||
@ -417,7 +421,7 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
||||||
logits_input=((0., 0.),), logits=self._logits)
|
logits_input=((0., 0.),), logits=self._logits)
|
||||||
|
|
||||||
def testMultiLabelEvalMode(self):
|
def testMultiLabelEval(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib.multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
@ -433,7 +437,7 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
_assert_metrics(self, expected_loss,
|
_assert_metrics(self, expected_loss,
|
||||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||||
|
|
||||||
def testMultiClassEvalModeWithLargeLogits(self):
|
def testMultiClassEvalWithLargeLogits(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib.multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
@ -472,6 +476,36 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
_assert_metrics(self, expected_loss,
|
_assert_metrics(self, expected_loss,
|
||||||
expected_eval_metrics, model_fn_ops)
|
expected_eval_metrics, model_fn_ops)
|
||||||
|
|
||||||
|
def testMultiLabelInfer(self):
|
||||||
|
n_classes = 3
|
||||||
|
head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")
|
||||||
|
with ops.Graph().as_default(), session.Session():
|
||||||
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
|
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
|
||||||
|
logits=((1., 0., 0.), (0., 0., 1)))
|
||||||
|
self.assertIsNone(model_fn_ops.train_op)
|
||||||
|
_assert_no_variables(self)
|
||||||
|
with session.Session():
|
||||||
|
self.assertListEqual(
|
||||||
|
[1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
|
||||||
|
self.assertEqual(
|
||||||
|
constants.ProblemType.CLASSIFICATION,
|
||||||
|
model_fn_ops.output_alternatives["head_name"][0])
|
||||||
|
|
||||||
|
predictions_for_serving = (
|
||||||
|
model_fn_ops.output_alternatives["head_name"][1])
|
||||||
|
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
|
||||||
|
predictions_for_serving["classes"].eval())
|
||||||
|
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||||
|
self.assertAllClose(
|
||||||
|
[[0.731059, 0.5, 0.5],
|
||||||
|
[0.5, 0.5, 0.731059,]],
|
||||||
|
predictions_for_serving["probabilities"].eval())
|
||||||
|
|
||||||
def testMultiLabelWithLabelName(self):
|
def testMultiLabelWithLabelName(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
label_name = "my_label"
|
label_name = "my_label"
|
||||||
@ -621,6 +655,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
"accuracy/baseline_label_mean": label_mean,
|
"accuracy/baseline_label_mean": label_mean,
|
||||||
"accuracy/threshold_0.500000_mean": 1. / 2,
|
"accuracy/threshold_0.500000_mean": 1. / 2,
|
||||||
"auc": 1. / 2,
|
"auc": 1. / 2,
|
||||||
|
"auc_precision_recall": 0.749999,
|
||||||
"labels/actual_label_mean": label_mean,
|
"labels/actual_label_mean": label_mean,
|
||||||
"labels/prediction_mean": .731059, # softmax
|
"labels/prediction_mean": .731059, # softmax
|
||||||
"loss": expected_loss,
|
"loss": expected_loss,
|
||||||
@ -691,7 +726,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
||||||
logits_input=((0., 0.), (0., 0.)), logits=self._logits)
|
logits_input=((0., 0.), (0., 0.)), logits=self._logits)
|
||||||
|
|
||||||
def testBinaryClassificationEvalMode(self):
|
def testBinaryClassificationEval(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
@ -708,18 +743,32 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
_assert_metrics(self, expected_loss,
|
_assert_metrics(self, expected_loss,
|
||||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||||
|
|
||||||
def testBinaryClassificationInferMode(self):
|
def testBinaryClassificationInfer(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
|
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
|
||||||
logits=self._logits)
|
logits=self._logits)
|
||||||
self._assert_output_alternatives(model_fn_ops)
|
|
||||||
self.assertIsNone(model_fn_ops.train_op)
|
self.assertIsNone(model_fn_ops.train_op)
|
||||||
_assert_no_variables(self)
|
_assert_no_variables(self)
|
||||||
|
with session.Session():
|
||||||
|
self.assertListEqual(
|
||||||
|
[1, 1], list(model_fn_ops.predictions["classes"].eval()))
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
|
||||||
|
self.assertEqual(
|
||||||
|
constants.ProblemType.LOGISTIC_REGRESSION,
|
||||||
|
model_fn_ops.output_alternatives["head_name"][0])
|
||||||
|
predictions_for_serving = (
|
||||||
|
model_fn_ops.output_alternatives["head_name"][1])
|
||||||
|
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||||
|
predicted_classes = predictions_for_serving["classes"].eval().tolist()
|
||||||
|
self.assertListEqual(
|
||||||
|
[b"0", b"1"], predicted_classes[0])
|
||||||
|
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||||
|
|
||||||
def testBinaryClassificationInferMode_withWightColumn(self):
|
def testBinaryClassificationInferMode_withWightColumn(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
@ -1006,7 +1055,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
"multi_class_head/centered_bias/bias_1",
|
"multi_class_head/centered_bias/bias_1",
|
||||||
"multi_class_head/centered_bias/bias_2"])
|
"multi_class_head/centered_bias/bias_2"])
|
||||||
|
|
||||||
def testMultiClassEvalMode(self):
|
def testMultiClassEval(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib.multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
@ -1131,7 +1180,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
model_fn_ops.output_alternatives["head_name"][1])
|
model_fn_ops.output_alternatives["head_name"][1])
|
||||||
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[[0, 1, 2], [0, 1, 2]],
|
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
|
||||||
predictions_for_serving["classes"].eval())
|
predictions_for_serving["classes"].eval())
|
||||||
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
|
@ -22,7 +22,9 @@ class MetricKey(object):
|
|||||||
"""Metric key strings."""
|
"""Metric key strings."""
|
||||||
LOSS = "loss"
|
LOSS = "loss"
|
||||||
AUC = "auc"
|
AUC = "auc"
|
||||||
|
AUC_PR = "auc_precision_recall"
|
||||||
CLASS_AUC = "auc/class%d"
|
CLASS_AUC = "auc/class%d"
|
||||||
|
CLASS_AUC_PR = "auc_precision_recall/class%d"
|
||||||
PREDICTION_MEAN = "labels/prediction_mean"
|
PREDICTION_MEAN = "labels/prediction_mean"
|
||||||
CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d"
|
CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d"
|
||||||
CLASS_LOGITS_MEAN = "labels/logits_mean/class%d"
|
CLASS_LOGITS_MEAN = "labels/logits_mean/class%d"
|
||||||
|
@ -25,10 +25,16 @@ import six
|
|||||||
|
|
||||||
from tensorflow.contrib import framework as contrib_framework
|
from tensorflow.contrib import framework as contrib_framework
|
||||||
from tensorflow.contrib.framework import get_graph_from_inputs
|
from tensorflow.contrib.framework import get_graph_from_inputs
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
|
from tensorflow.python.estimator import model_fn as core_model_fn_lib
|
||||||
|
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
|
|
||||||
|
|
||||||
@ -177,3 +183,85 @@ class ModelFnOps(
|
|||||||
training_chief_hooks=training_chief_hooks,
|
training_chief_hooks=training_chief_hooks,
|
||||||
training_hooks=training_hooks,
|
training_hooks=training_hooks,
|
||||||
scaffold=scaffold)
|
scaffold=scaffold)
|
||||||
|
|
||||||
|
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
|
||||||
|
"""Creates an equivalent `EstimatorSpec`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: One of `ModeKeys`. Specifies if this training, evaluation or
|
||||||
|
prediction.
|
||||||
|
default_serving_output_alternative_key: Required for multiple heads. If
|
||||||
|
you have multiple entries in `output_alternatives` dict (comparable to
|
||||||
|
multiple heads), `EstimatorSpec` requires a default head that will be
|
||||||
|
used if a Servo request does not explicitly mention which head to infer
|
||||||
|
on. Pass the key of the output alternative here that you want to
|
||||||
|
designate as default. A separate ExportOutpout for this default head
|
||||||
|
wil be added to the export_outputs dict with the special key
|
||||||
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
|
||||||
|
already an enry in output_alternatives with this special key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If problem type is unknown.
|
||||||
|
"""
|
||||||
|
def _scores(output_tensors):
|
||||||
|
scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
|
||||||
|
if scores is None:
|
||||||
|
scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def _classes(output_tensors): # pylint: disable=missing-docstring
|
||||||
|
classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
|
||||||
|
if classes is None:
|
||||||
|
logging.warning(
|
||||||
|
'classes is None, Servo inference will not have class ids.')
|
||||||
|
return None
|
||||||
|
elif classes.dtype != dtypes.string:
|
||||||
|
# Servo classification can only serve string classes
|
||||||
|
logging.warning(
|
||||||
|
'classes is not string, Servo inference will not have class ids.')
|
||||||
|
return None
|
||||||
|
|
||||||
|
return classes
|
||||||
|
|
||||||
|
def _export_output(problem_type, predictions): # pylint: disable=missing-docstring
|
||||||
|
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
|
||||||
|
return core_export_lib.RegressionOutput(_scores(predictions))
|
||||||
|
|
||||||
|
if (problem_type == constants.ProblemType.CLASSIFICATION or
|
||||||
|
problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
|
||||||
|
return core_export_lib.ClassificationOutput(
|
||||||
|
scores=_scores(predictions), classes=_classes(predictions))
|
||||||
|
|
||||||
|
if problem_type == constants.ProblemType.UNSPECIFIED:
|
||||||
|
return core_export_lib.PredictOutput(predictions)
|
||||||
|
|
||||||
|
raise ValueError('Unknown problem_type=%s' % problem_type)
|
||||||
|
|
||||||
|
# Converts output_alternatives
|
||||||
|
export_outputs_dict = None
|
||||||
|
if self.output_alternatives:
|
||||||
|
output_alternatives = self.output_alternatives
|
||||||
|
# Adds default output_alternative if needed.
|
||||||
|
if (len(output_alternatives) > 1 and
|
||||||
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
|
||||||
|
output_alternatives):
|
||||||
|
output_alternatives = output_alternatives.copy()
|
||||||
|
output_alternatives[
|
||||||
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
|
||||||
|
output_alternatives[default_serving_output_alternative_key])
|
||||||
|
export_outputs_dict = {key: _export_output(*val) for key, val in
|
||||||
|
output_alternatives.items()}
|
||||||
|
|
||||||
|
return core_model_fn_lib.EstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
predictions=self.predictions,
|
||||||
|
loss=self.loss,
|
||||||
|
train_op=self.train_op,
|
||||||
|
eval_metric_ops=self.eval_metric_ops,
|
||||||
|
export_outputs=export_outputs_dict,
|
||||||
|
training_chief_hooks=self.training_chief_hooks,
|
||||||
|
training_hooks=self.training_hooks,
|
||||||
|
scaffold=self.scaffold)
|
||||||
|
@ -0,0 +1,279 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ModelFnOps tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.saved_model import signature_constants
|
||||||
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import monitored_session
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFnopsTest(test.TestCase):
|
||||||
|
"""Multi-output tests."""
|
||||||
|
|
||||||
|
def create_predictions(self):
|
||||||
|
probabilities = constant_op.constant([1., 1., 1.])
|
||||||
|
scores = constant_op.constant([1., 2., 3.])
|
||||||
|
classes = constant_op.constant([b"0", b"1", b"2"])
|
||||||
|
return {
|
||||||
|
"probabilities": probabilities,
|
||||||
|
"scores": scores,
|
||||||
|
"classes": classes}
|
||||||
|
|
||||||
|
def create_model_fn_ops(self, predictions, output_alternatives,
|
||||||
|
mode=model_fn.ModeKeys.INFER):
|
||||||
|
|
||||||
|
return model_fn.ModelFnOps(
|
||||||
|
model_fn.ModeKeys.INFER,
|
||||||
|
predictions=predictions,
|
||||||
|
loss=constant_op.constant([1]),
|
||||||
|
train_op=control_flow_ops.no_op(),
|
||||||
|
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
|
||||||
|
control_flow_ops.no_op())},
|
||||||
|
# zzz
|
||||||
|
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||||
|
training_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||||
|
output_alternatives=output_alternatives,
|
||||||
|
scaffold=monitored_session.Scaffold())
|
||||||
|
|
||||||
|
def assertEquals_except_export(self, model_fn_ops, estimator_spec):
|
||||||
|
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
|
||||||
|
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
|
||||||
|
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
|
||||||
|
self.assertEqual(model_fn_ops.eval_metric_ops,
|
||||||
|
estimator_spec.eval_metric_ops)
|
||||||
|
self.assertEqual(model_fn_ops.training_chief_hooks,
|
||||||
|
estimator_spec.training_chief_hooks)
|
||||||
|
self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks)
|
||||||
|
self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)
|
||||||
|
|
||||||
|
def testEstimatorSpec_except_export(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, None)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_regression_with_scores(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives = {"regression_head": (
|
||||||
|
constants.ProblemType.LINEAR_REGRESSION, predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||||
|
self.assertTrue(isinstance(
|
||||||
|
regression_output, core_export_lib.RegressionOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
regression_output.value.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_regression_with_probabilities(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives_predictions = predictions.copy()
|
||||||
|
del output_alternatives_predictions["scores"]
|
||||||
|
output_alternatives = {"regression_head": (
|
||||||
|
constants.ProblemType.LINEAR_REGRESSION,
|
||||||
|
output_alternatives_predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||||
|
self.assertTrue(isinstance(
|
||||||
|
regression_output, core_export_lib.RegressionOutput))
|
||||||
|
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||||
|
regression_output.value.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_classsification(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives = {"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
classification_output = estimator_spec.export_outputs[
|
||||||
|
"classification_head"]
|
||||||
|
self.assertTrue(isinstance(classification_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
classification_output.scores.eval())
|
||||||
|
self.assertAllEqual(predictions["classes"].eval(),
|
||||||
|
classification_output.classes.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_classsification_with_missing_scores(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives_predictions = predictions.copy()
|
||||||
|
del output_alternatives_predictions["scores"]
|
||||||
|
output_alternatives = {"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
classification_output = estimator_spec.export_outputs[
|
||||||
|
"classification_head"]
|
||||||
|
self.assertTrue(isinstance(classification_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||||
|
classification_output.scores.eval())
|
||||||
|
self.assertAllEqual(predictions["classes"].eval(),
|
||||||
|
classification_output.classes.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_classsification_with_missing_scores_proba(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives_predictions = predictions.copy()
|
||||||
|
del output_alternatives_predictions["scores"]
|
||||||
|
del output_alternatives_predictions["probabilities"]
|
||||||
|
output_alternatives = {"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
classification_output = estimator_spec.export_outputs[
|
||||||
|
"classification_head"]
|
||||||
|
self.assertTrue(isinstance(classification_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertIsNone(classification_output.scores)
|
||||||
|
self.assertAllEqual(predictions["classes"].eval(),
|
||||||
|
classification_output.classes.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_classsification_with_missing_classes(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives_predictions = predictions.copy()
|
||||||
|
del output_alternatives_predictions["classes"]
|
||||||
|
output_alternatives = {"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
classification_output = estimator_spec.export_outputs[
|
||||||
|
"classification_head"]
|
||||||
|
self.assertTrue(isinstance(classification_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
classification_output.scores.eval())
|
||||||
|
self.assertIsNone(classification_output.classes)
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_classsification_with_nonstring_classes(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives_predictions = predictions.copy()
|
||||||
|
output_alternatives_predictions["classes"] = constant_op.constant(
|
||||||
|
[1, 2, 3])
|
||||||
|
output_alternatives = {"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
classification_output = estimator_spec.export_outputs[
|
||||||
|
"classification_head"]
|
||||||
|
self.assertTrue(isinstance(classification_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
classification_output.scores.eval())
|
||||||
|
self.assertIsNone(classification_output.classes)
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_logistic(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives = {"logistic_head": (
|
||||||
|
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
logistic_output = estimator_spec.export_outputs["logistic_head"]
|
||||||
|
self.assertTrue(isinstance(logistic_output,
|
||||||
|
core_export_lib.ClassificationOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
logistic_output.scores.eval())
|
||||||
|
self.assertAllEqual(predictions["classes"].eval(),
|
||||||
|
logistic_output.classes.eval())
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_unspecified(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives = {"unspecified_head": (
|
||||||
|
constants.ProblemType.UNSPECIFIED, predictions)}
|
||||||
|
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
|
||||||
|
self.assertTrue(isinstance(unspecified_output,
|
||||||
|
core_export_lib.PredictOutput))
|
||||||
|
self.assertEqual(predictions, unspecified_output.outputs)
|
||||||
|
|
||||||
|
def testEstimatorSpec_export_multihead(self):
|
||||||
|
predictions = self.create_predictions()
|
||||||
|
output_alternatives = {
|
||||||
|
"regression_head": (
|
||||||
|
constants.ProblemType.LINEAR_REGRESSION, predictions),
|
||||||
|
"classification_head": (
|
||||||
|
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||||
|
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||||
|
|
||||||
|
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
|
||||||
|
"regression_head")
|
||||||
|
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||||
|
|
||||||
|
with session.Session():
|
||||||
|
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||||
|
self.assertTrue(isinstance(
|
||||||
|
regression_output, core_export_lib.RegressionOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
regression_output.value.eval())
|
||||||
|
|
||||||
|
default_output = estimator_spec.export_outputs[
|
||||||
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||||
|
self.assertTrue(isinstance(default_output,
|
||||||
|
core_export_lib.RegressionOutput))
|
||||||
|
self.assertAllEqual(predictions["scores"].eval(),
|
||||||
|
default_output.value.eval())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -66,8 +66,8 @@ def _get_single_cell(cell_type, num_units):
|
|||||||
ValueError: `cell_type` is an invalid `RNNCell` name.
|
ValueError: `cell_type` is an invalid `RNNCell` name.
|
||||||
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
|
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
|
||||||
"""
|
"""
|
||||||
cell_type = _CELL_TYPES.get(cell_type)
|
cell_type = _CELL_TYPES.get(cell_type, cell_type)
|
||||||
if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell):
|
if not cell_type or not issubclass(cell_type, contrib_rnn.RNNCell):
|
||||||
raise ValueError('The supported cell types are {}; got {}'.format(
|
raise ValueError('The supported cell types are {}; got {}'.format(
|
||||||
list(_CELL_TYPES.keys()), cell_type))
|
list(_CELL_TYPES.keys()), cell_type))
|
||||||
return cell_type(num_units=num_units)
|
return cell_type(num_units=num_units)
|
||||||
|
@ -97,7 +97,8 @@ class Experiment(object):
|
|||||||
finite number of batches (generally, 1 epoch over the evaluation data).
|
finite number of batches (generally, 1 epoch over the evaluation data).
|
||||||
eval_metrics: `dict` of string, metric function. If `None`, default set
|
eval_metrics: `dict` of string, metric function. If `None`, default set
|
||||||
is used. This should be `None` if the `estimator` is
|
is used. This should be `None` if the `estimator` is
|
||||||
${tf.estimator.Estimator}.
|
${tf.estimator.Estimator}. If metrics are provided they will be
|
||||||
|
*appended* to the default set.
|
||||||
train_steps: Perform this many steps of training. `None`, the default,
|
train_steps: Perform this many steps of training. `None`, the default,
|
||||||
means train forever.
|
means train forever.
|
||||||
eval_steps: `evaluate` runs until input is exhausted (or another exception
|
eval_steps: `evaluate` runs until input is exhausted (or another exception
|
||||||
|
@ -45,7 +45,7 @@ class SquareLinearOperatorFullMatrixTest(
|
|||||||
# values are random and we want the same value used for both mat and
|
# values are random and we want the same value used for both mat and
|
||||||
# feed_dict.
|
# feed_dict.
|
||||||
matrix = matrix.eval()
|
matrix = matrix.eval()
|
||||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
|
||||||
feed_dict = {matrix_ph: matrix}
|
feed_dict = {matrix_ph: matrix}
|
||||||
else:
|
else:
|
||||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||||
@ -105,7 +105,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
|
|||||||
# feed_dict.
|
# feed_dict.
|
||||||
matrix = matrix.eval()
|
matrix = matrix.eval()
|
||||||
operator = linalg.LinearOperatorFullMatrix(
|
operator = linalg.LinearOperatorFullMatrix(
|
||||||
matrix, is_self_adjoint=True, is_positive_definite=True)
|
matrix_ph, is_self_adjoint=True, is_positive_definite=True)
|
||||||
feed_dict = {matrix_ph: matrix}
|
feed_dict = {matrix_ph: matrix}
|
||||||
else:
|
else:
|
||||||
operator = linalg.LinearOperatorFullMatrix(
|
operator = linalg.LinearOperatorFullMatrix(
|
||||||
@ -144,7 +144,7 @@ class NonSquareLinearOperatorFullMatrixTest(
|
|||||||
# values are random and we want the same value used for both mat and
|
# values are random and we want the same value used for both mat and
|
||||||
# feed_dict.
|
# feed_dict.
|
||||||
matrix = matrix.eval()
|
matrix = matrix.eval()
|
||||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
|
||||||
feed_dict = {matrix_ph: matrix}
|
feed_dict = {matrix_ph: matrix}
|
||||||
else:
|
else:
|
||||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||||
|
@ -12,13 +12,25 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Ops for nccl AllReduce."""
|
"""Functions for using NVIDIA nccl collective ops.
|
||||||
|
|
||||||
|
@@all_max
|
||||||
|
@@all_min
|
||||||
|
@@all_prod
|
||||||
|
@@all_sum
|
||||||
|
@@broadcast
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max
|
||||||
# pylint: disable=wildcard-import
|
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
|
||||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import *
|
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
|
||||||
# pylint: enable=wildcard-import
|
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
|
||||||
|
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
|
||||||
|
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
remove_undocumented(__name__)
|
||||||
|
@ -66,7 +66,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([batch_size, input_size])
|
x = array_ops.zeros([batch_size, input_size])
|
||||||
m = array_ops.zeros([batch_size, state_size])
|
m = array_ops.zeros([batch_size, state_size])
|
||||||
output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
|
output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
|
||||||
num_units=num_units, forget_bias=1.0)(x, m)
|
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([output, state], {
|
res = sess.run([output, state], {
|
||||||
x.name:
|
x.name:
|
||||||
|
@ -466,12 +466,13 @@ class OutputProjectionWrapper(RNNCell):
|
|||||||
if needed or directly feed into a softmax.
|
if needed or directly feed into a softmax.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cell, output_size, reuse=None):
|
def __init__(self, cell, output_size, activation=None, reuse=None):
|
||||||
"""Create a cell with output projection.
|
"""Create a cell with output projection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell: an RNNCell, a projection to output_size is added to it.
|
cell: an RNNCell, a projection to output_size is added to it.
|
||||||
output_size: integer, the size of the output after projection.
|
output_size: integer, the size of the output after projection.
|
||||||
|
activation: (optional) an optional activation function.
|
||||||
reuse: (optional) Python boolean describing whether to reuse variables
|
reuse: (optional) Python boolean describing whether to reuse variables
|
||||||
in an existing scope. If not `True`, and the existing scope already has
|
in an existing scope. If not `True`, and the existing scope already has
|
||||||
the given variables, an error is raised.
|
the given variables, an error is raised.
|
||||||
@ -487,6 +488,7 @@ class OutputProjectionWrapper(RNNCell):
|
|||||||
self._cell = cell
|
self._cell = cell
|
||||||
self._output_size = output_size
|
self._output_size = output_size
|
||||||
self._reuse = reuse
|
self._reuse = reuse
|
||||||
|
self._activation = activation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_size(self):
|
def state_size(self):
|
||||||
@ -507,6 +509,8 @@ class OutputProjectionWrapper(RNNCell):
|
|||||||
with _checked_scope(self, scope or "output_projection_wrapper",
|
with _checked_scope(self, scope or "output_projection_wrapper",
|
||||||
reuse=self._reuse):
|
reuse=self._reuse):
|
||||||
projected = _linear(output, self._output_size, True)
|
projected = _linear(output, self._output_size, True)
|
||||||
|
if self._activation:
|
||||||
|
projected = self._activation(projected)
|
||||||
return projected, res_state
|
return projected, res_state
|
||||||
|
|
||||||
|
|
||||||
@ -518,12 +522,13 @@ class InputProjectionWrapper(RNNCell):
|
|||||||
do the projection on this batch-concatenated sequence, then split it.
|
do the projection on this batch-concatenated sequence, then split it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cell, num_proj, input_size=None):
|
def __init__(self, cell, num_proj, activation=None, input_size=None):
|
||||||
"""Create a cell with input projection.
|
"""Create a cell with input projection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell: an RNNCell, a projection of inputs is added before it.
|
cell: an RNNCell, a projection of inputs is added before it.
|
||||||
num_proj: Python integer. The dimension to project to.
|
num_proj: Python integer. The dimension to project to.
|
||||||
|
activation: (optional) an optional activation function.
|
||||||
input_size: Deprecated and unused.
|
input_size: Deprecated and unused.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -535,6 +540,7 @@ class InputProjectionWrapper(RNNCell):
|
|||||||
raise TypeError("The parameter cell is not RNNCell.")
|
raise TypeError("The parameter cell is not RNNCell.")
|
||||||
self._cell = cell
|
self._cell = cell
|
||||||
self._num_proj = num_proj
|
self._num_proj = num_proj
|
||||||
|
self._activation = activation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_size(self):
|
def state_size(self):
|
||||||
@ -553,6 +559,8 @@ class InputProjectionWrapper(RNNCell):
|
|||||||
# Default scope: "InputProjectionWrapper"
|
# Default scope: "InputProjectionWrapper"
|
||||||
with vs.variable_scope(scope or "input_projection_wrapper"):
|
with vs.variable_scope(scope or "input_projection_wrapper"):
|
||||||
projected = _linear(inputs, self._num_proj, True)
|
projected = _linear(inputs, self._num_proj, True)
|
||||||
|
if self._activation:
|
||||||
|
projected = self._activation(projected)
|
||||||
return self._cell(projected, state)
|
return self._cell(projected, state)
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
|
|||||||
def __init__(self, num_units, use_peepholes=False,
|
def __init__(self, num_units, use_peepholes=False,
|
||||||
initializer=None, num_proj=None, proj_clip=None,
|
initializer=None, num_proj=None, proj_clip=None,
|
||||||
num_unit_shards=1, num_proj_shards=1,
|
num_unit_shards=1, num_proj_shards=1,
|
||||||
forget_bias=1.0, state_is_tuple=False,
|
forget_bias=1.0, state_is_tuple=True,
|
||||||
activation=math_ops.tanh, reuse=None):
|
activation=math_ops.tanh, reuse=None):
|
||||||
"""Initialize the parameters for an LSTM cell.
|
"""Initialize the parameters for an LSTM cell.
|
||||||
|
|
||||||
@ -457,7 +457,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
|
|||||||
start_freqindex_list=None,
|
start_freqindex_list=None,
|
||||||
end_freqindex_list=None,
|
end_freqindex_list=None,
|
||||||
couple_input_forget_gates=False,
|
couple_input_forget_gates=False,
|
||||||
state_is_tuple=False,
|
state_is_tuple=True,
|
||||||
reuse=None):
|
reuse=None):
|
||||||
"""Initialize the parameters for an LSTM cell.
|
"""Initialize the parameters for an LSTM cell.
|
||||||
|
|
||||||
@ -571,7 +571,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
|
|||||||
ValueError: if an input_size was specified and the provided inputs have
|
ValueError: if an input_size was specified and the provided inputs have
|
||||||
a different dimension.
|
a different dimension.
|
||||||
"""
|
"""
|
||||||
batch_size = int(inputs.get_shape()[0])
|
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||||
freq_inputs = self._make_tf_features(inputs)
|
freq_inputs = self._make_tf_features(inputs)
|
||||||
with _checked_scope(self, scope or "grid_lstm_cell",
|
with _checked_scope(self, scope or "grid_lstm_cell",
|
||||||
initializer=self._initializer, reuse=self._reuse):
|
initializer=self._initializer, reuse=self._reuse):
|
||||||
@ -994,7 +994,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
|
|||||||
ValueError: if an input_size was specified and the provided inputs have
|
ValueError: if an input_size was specified and the provided inputs have
|
||||||
a different dimension.
|
a different dimension.
|
||||||
"""
|
"""
|
||||||
batch_size = int(inputs.get_shape()[0])
|
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||||
fwd_inputs = self._make_tf_features(inputs)
|
fwd_inputs = self._make_tf_features(inputs)
|
||||||
if self._backward_slice_offset:
|
if self._backward_slice_offset:
|
||||||
bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
|
bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
|
||||||
@ -1043,7 +1043,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
|
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
|
||||||
input_size=None, state_is_tuple=False, reuse=None):
|
input_size=None, state_is_tuple=True, reuse=None):
|
||||||
"""Create a cell with attention.
|
"""Create a cell with attention.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -56,6 +56,13 @@ class AttentionWrapperTest(test.TestCase):
|
|||||||
return super(AttentionWrapperTest, self).assertAllClose(
|
return super(AttentionWrapperTest, self).assertAllClose(
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
|
def testAttentionWrapperState(self):
|
||||||
|
num_fields = len(wrapper.AttentionWrapperState._fields) # pylint: disable=protected-access
|
||||||
|
state = wrapper.AttentionWrapperState(*([None] * num_fields))
|
||||||
|
new_state = state.clone(time=1)
|
||||||
|
self.assertEqual(state.time, None)
|
||||||
|
self.assertEqual(new_state.time, 1)
|
||||||
|
|
||||||
def _testWithAttention(self,
|
def _testWithAttention(self,
|
||||||
create_attention_mechanism,
|
create_attention_mechanism,
|
||||||
expected_final_output,
|
expected_final_output,
|
||||||
|
@ -369,7 +369,26 @@ class AttentionWrapperState(
|
|||||||
- `attention_history`: (if enabled) a `TensorArray` containing attention
|
- `attention_history`: (if enabled) a `TensorArray` containing attention
|
||||||
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
|
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
def clone(self, **kwargs):
|
||||||
|
"""Clone this object, overriding components provided by kwargs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
|
||||||
|
initial_state = initial_state.clone(cell_state=encoder_state)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Any properties of the state object to replace in the returned
|
||||||
|
`AttentionWrapperState`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new `AttentionWrapperState` whose properties are the same as
|
||||||
|
this one, except any overriden properties as provided in `kwargs`.
|
||||||
|
"""
|
||||||
|
return super(AttentionWrapperState, self)._replace(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def hardmax(logits, name=None):
|
def hardmax(logits, name=None):
|
||||||
|
@ -431,7 +431,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
|||||||
shape=base_shape))
|
shape=base_shape))
|
||||||
|
|
||||||
all_finished = math_ops.reduce_all(finished)
|
all_finished = math_ops.reduce_all(finished)
|
||||||
no_samples = math_ops.equal(array_ops.shape(sample_ids)[0], 0)
|
no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
|
||||||
next_inputs = control_flow_ops.cond(
|
next_inputs = control_flow_ops.cond(
|
||||||
math_ops.logical_or(all_finished, no_samples),
|
math_ops.logical_or(all_finished, no_samples),
|
||||||
lambda: base_next_inputs, maybe_sample)
|
lambda: base_next_inputs, maybe_sample)
|
||||||
|
@ -31,6 +31,8 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import monitored_session
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
|
|
||||||
|
|
||||||
@ -95,6 +97,22 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
|
|||||||
run_context.request_stop()
|
run_context.request_stop()
|
||||||
|
|
||||||
|
|
||||||
|
class EveryCheckpointPreSaveListener(
|
||||||
|
basic_session_run_hooks.CheckpointSaverListener):
|
||||||
|
"""Runs a given op before each checkpoint save."""
|
||||||
|
|
||||||
|
def __init__(self, op):
|
||||||
|
"""Initializes the object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: An op to run before each checkpoint save.
|
||||||
|
"""
|
||||||
|
self._op = op
|
||||||
|
|
||||||
|
def before_save(self, session, global_step_value):
|
||||||
|
session.run(self._op)
|
||||||
|
|
||||||
|
|
||||||
def get_model_fn(params,
|
def get_model_fn(params,
|
||||||
graph_builder_class,
|
graph_builder_class,
|
||||||
device_assigner,
|
device_assigner,
|
||||||
@ -103,6 +121,7 @@ def get_model_fn(params,
|
|||||||
num_trainers=1,
|
num_trainers=1,
|
||||||
trainer_id=0,
|
trainer_id=0,
|
||||||
report_feature_importances=False,
|
report_feature_importances=False,
|
||||||
|
model_dir=None,
|
||||||
local_eval=False):
|
local_eval=False):
|
||||||
"""Return a model function given a way to construct a graph builder."""
|
"""Return a model function given a way to construct a graph builder."""
|
||||||
def _model_fn(features, labels, mode):
|
def _model_fn(features, labels, mode):
|
||||||
@ -138,6 +157,8 @@ def get_model_fn(params,
|
|||||||
# question of why we force everything to adhere to a single model_fn).
|
# question of why we force everything to adhere to a single model_fn).
|
||||||
loss_deps = []
|
loss_deps = []
|
||||||
training_graph = None
|
training_graph = None
|
||||||
|
training_hooks = []
|
||||||
|
scaffold = None
|
||||||
if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
|
if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
|
||||||
training_graph = control_flow_ops.group(
|
training_graph = control_flow_ops.group(
|
||||||
graph_builder.training_graph(
|
graph_builder.training_graph(
|
||||||
@ -146,6 +167,15 @@ def get_model_fn(params,
|
|||||||
trainer_id=trainer_id),
|
trainer_id=trainer_id),
|
||||||
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
||||||
loss_deps.append(training_graph)
|
loss_deps.append(training_graph)
|
||||||
|
if hasattr(graph_builder, 'finalize_training'):
|
||||||
|
finalize_listener = EveryCheckpointPreSaveListener(
|
||||||
|
graph_builder.finalize_training())
|
||||||
|
scaffold = monitored_session.Scaffold()
|
||||||
|
training_hooks.append(
|
||||||
|
basic_session_run_hooks.CheckpointSaverHook(
|
||||||
|
model_dir, save_secs=600, save_steps=None,
|
||||||
|
scaffold=scaffold,
|
||||||
|
listeners=[finalize_listener]))
|
||||||
|
|
||||||
training_loss = None
|
training_loss = None
|
||||||
if (mode == model_fn_lib.ModeKeys.EVAL or
|
if (mode == model_fn_lib.ModeKeys.EVAL or
|
||||||
@ -158,7 +188,6 @@ def get_model_fn(params,
|
|||||||
if weights is not None:
|
if weights is not None:
|
||||||
features[weights_name] = weights
|
features[weights_name] = weights
|
||||||
|
|
||||||
training_hooks = []
|
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
training_hooks.append(TensorForestLossHook(early_stopping_rounds))
|
training_hooks.append(TensorForestLossHook(early_stopping_rounds))
|
||||||
|
|
||||||
@ -167,7 +196,9 @@ def get_model_fn(params,
|
|||||||
predictions=inference,
|
predictions=inference,
|
||||||
loss=training_loss,
|
loss=training_loss,
|
||||||
train_op=training_graph,
|
train_op=training_graph,
|
||||||
training_hooks=training_hooks)
|
training_hooks=training_hooks,
|
||||||
|
scaffold=scaffold)
|
||||||
|
|
||||||
return _model_fn
|
return _model_fn
|
||||||
|
|
||||||
|
|
||||||
@ -257,6 +288,7 @@ class TensorForestEstimator(estimator.Estimator):
|
|||||||
num_trainers=num_trainers,
|
num_trainers=num_trainers,
|
||||||
trainer_id=trainer_id,
|
trainer_id=trainer_id,
|
||||||
report_feature_importances=report_feature_importances,
|
report_feature_importances=report_feature_importances,
|
||||||
|
model_dir=model_dir,
|
||||||
local_eval=local_eval),
|
local_eval=local_eval),
|
||||||
model_dir=model_dir,
|
model_dir=model_dir,
|
||||||
config=config,
|
config=config,
|
||||||
|
@ -43,9 +43,9 @@ py_library(
|
|||||||
srcs = ["plugins/projector/__init__.py"],
|
srcs = ["plugins/projector/__init__.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":protos_all_py",
|
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
||||||
|
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -56,10 +56,10 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":projector",
|
":projector",
|
||||||
":protos_all_py",
|
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:summary",
|
"//tensorflow/python:summary",
|
||||||
|
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,11 +28,10 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
|
|
||||||
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
|
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.tensorboard.plugins.projector import projector_plugin
|
from tensorflow.tensorboard.plugins.projector import projector_plugin
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.tensorboard.plugins.projector.projector_config_pb2 import *
|
||||||
from tensorflow.tensorboard.plugins.projector.projector_plugin import *
|
from tensorflow.tensorboard.plugins.projector.projector_plugin import *
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
@ -24,10 +24,10 @@ import shutil
|
|||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
from tensorflow.contrib.tensorboard.plugins import projector
|
from tensorflow.contrib.tensorboard.plugins import projector
|
||||||
from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2
|
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.summary.writer import writer as writer_lib
|
from tensorflow.python.summary.writer import writer as writer_lib
|
||||||
|
from tensorflow.tensorboard.plugins.projector import projector_config_pb2
|
||||||
|
|
||||||
|
|
||||||
class ProjectorApiTest(test.TestCase):
|
class ProjectorApiTest(test.TestCase):
|
||||||
|
62
tensorflow/contrib/xla_tf_graph/BUILD
Normal file
62
tensorflow/contrib/xla_tf_graph/BUILD
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Description:
|
||||||
|
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "xla_tf_graph_util",
|
||||||
|
srcs = [
|
||||||
|
"xla_tf_graph_util.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"xla_tf_graph_util.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla/client",
|
||||||
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "xla_tf_graph_util_test",
|
||||||
|
srcs = ["xla_tf_graph_util_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":xla_tf_graph_util",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/cc:function_ops",
|
||||||
|
"//tensorflow/cc:scope",
|
||||||
|
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||||
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
|
],
|
||||||
|
)
|
8
tensorflow/contrib/xla_tf_graph/README.md
Normal file
8
tensorflow/contrib/xla_tf_graph/README.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Xla Tf Graph
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools.
|
||||||
|
|
||||||
|
Maintainers:
|
||||||
|
- Satoshi Kataoka (satok@google.com, github.com/satok16)
|
71
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
Normal file
71
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace xla_tf_graph {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr const char* const GRAPH_NAME = "xla_tf_graph_util";
|
||||||
|
|
||||||
|
void SetupXlaCpuClient(std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||||
|
std::unique_ptr<FunctionLibraryRuntime>* flr,
|
||||||
|
std::unique_ptr<XlaCompiler>* compiler) {
|
||||||
|
xla::Client* client = xla::ClientLibrary::LocalClientOrDie();
|
||||||
|
XlaOpRegistry::RegisterCompilationKernels();
|
||||||
|
|
||||||
|
FunctionDefLibrary flib;
|
||||||
|
flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||||
|
|
||||||
|
// Setup compiler options
|
||||||
|
XlaCompiler::Options options;
|
||||||
|
options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||||
|
options.client = client;
|
||||||
|
compiler->reset(new XlaCompiler(options));
|
||||||
|
|
||||||
|
flr->reset(NewFunctionLibraryRuntime(
|
||||||
|
compiler->get()->device_mgr(), /*env=*/nullptr, compiler->get()->device(),
|
||||||
|
TF_GRAPH_DEF_VERSION, flib_def->get(), OptimizerOptions(),
|
||||||
|
/*custom_kernel_creator=*/nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
|
||||||
|
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
|
||||||
|
std::unique_ptr<Graph> graph) {
|
||||||
|
CHECK(graph);
|
||||||
|
|
||||||
|
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
||||||
|
std::unique_ptr<FunctionLibraryRuntime> flr;
|
||||||
|
std::unique_ptr<XlaCompiler> compiler;
|
||||||
|
|
||||||
|
SetupXlaCpuClient(&flib_def, &flr, &compiler);
|
||||||
|
|
||||||
|
// Compile graph and build computation
|
||||||
|
XlaCompiler::CompilationResult result;
|
||||||
|
TF_CHECK_OK(compiler->CompileGraph(GRAPH_NAME, std::move(graph), flr.get(),
|
||||||
|
args, &result));
|
||||||
|
|
||||||
|
return result.computation.Snapshot();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_tf_graph
|
||||||
|
} // namespace tensorflow
|
43
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h
Normal file
43
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/client.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace xla_tf_graph {
|
||||||
|
|
||||||
|
// A set of utilities to handle xla computation requests.
|
||||||
|
// These utilities help developers leverage existing tools to work with
|
||||||
|
// xla computations, also provide a way to support TensorFlow ops by
|
||||||
|
// implementing xla computations so that they can do experiments on their
|
||||||
|
// specialized environments.
|
||||||
|
|
||||||
|
// Convert a tf graph to a xla session module
|
||||||
|
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
|
||||||
|
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
|
||||||
|
std::unique_ptr<Graph> graph);
|
||||||
|
|
||||||
|
} // namespace xla_tf_graph
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
57
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
Normal file
57
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
|
||||||
|
#include "tensorflow/cc/framework/scope.h"
|
||||||
|
#include "tensorflow/cc/ops/function_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace xla_tf_graph {
|
||||||
|
|
||||||
|
static std::unique_ptr<Graph> BuildAddGraph() {
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||||
|
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
|
||||||
|
auto c = ops::Add(scope.WithOpName("C"), a, b);
|
||||||
|
auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_CHECK_OK(scope.ToGraph(graph.get()));
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) {
|
||||||
|
// Builds a description of the arguments.
|
||||||
|
std::vector<XlaCompiler::Argument> args(2);
|
||||||
|
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[0].type = DT_INT32;
|
||||||
|
args[0].shape = TensorShape({2});
|
||||||
|
args[1].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[1].type = DT_INT32;
|
||||||
|
args[1].shape = TensorShape({2});
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph = BuildAddGraph();
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(
|
||||||
|
std::unique_ptr<xla::SessionModule> session_module,
|
||||||
|
ConvertTfGraphToXlaSessionModule(args, std::move(graph)));
|
||||||
|
|
||||||
|
ASSERT_EQ(5, session_module->entry().requests_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_tf_graph
|
||||||
|
} // namespace tensorflow
|
@ -453,8 +453,8 @@ void BFCAllocator::RemoveFreeChunkIterFromBin(
|
|||||||
void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
|
void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
|
||||||
Chunk* c = ChunkFromHandle(h);
|
Chunk* c = ChunkFromHandle(h);
|
||||||
CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
|
CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
|
||||||
int count = BinFromIndex(c->bin_num)->free_chunks.erase(h);
|
CHECK_GT(BinFromIndex(c->bin_num)->free_chunks.erase(h), 0)
|
||||||
CHECK(count > 0) << "Could not find chunk in bin";
|
<< "Could not find chunk in bin";
|
||||||
c->bin_num = kInvalidBinNum;
|
c->bin_num = kInvalidBinNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ class BFCAllocator : public VisitableAllocator {
|
|||||||
|
|
||||||
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
||||||
// kInvalidChunkHandle means an invalid chunk
|
// kInvalidChunkHandle means an invalid chunk
|
||||||
typedef int ChunkHandle;
|
typedef size_t ChunkHandle;
|
||||||
static const int kInvalidChunkHandle = -1;
|
static const int kInvalidChunkHandle = -1;
|
||||||
|
|
||||||
typedef int BinNum;
|
typedef int BinNum;
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/graph/subgraph.h"
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
@ -304,10 +305,18 @@ Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
|||||||
tensors_to_replace.push_back({n.second, n.first.second});
|
tensors_to_replace.push_back({n.second, n.first.second});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
|
||||||
// Evaluate the constant foldable nodes.
|
// Evaluate the constant foldable nodes.
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = GraphRunner::Run(constant_graph.get(), function_library, env,
|
auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
|
||||||
{} /* inputs*/, tensors_to_fetch_names, &outputs);
|
// Output tensors need to be cleared before the GraphRunner is deleted.
|
||||||
|
outputs.clear();
|
||||||
|
graph_runner.reset(nullptr);
|
||||||
|
});
|
||||||
|
|
||||||
|
Status s =
|
||||||
|
graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
|
||||||
|
tensors_to_fetch_names, &outputs);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
VLOG(1) << "Could not fetch constants: " << s;
|
VLOG(1) << "Could not fetch constants: " << s;
|
||||||
*was_mutated = false;
|
*was_mutated = false;
|
||||||
|
@ -44,7 +44,7 @@ DeviceMgr::~DeviceMgr() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||||
int n = s.size();
|
size_t n = s.size();
|
||||||
char* space = name_backing_store_.Alloc(n);
|
char* space = name_backing_store_.Alloc(n);
|
||||||
memcpy(space, s.data(), n);
|
memcpy(space, s.data(), n);
|
||||||
return StringPiece(space, n);
|
return StringPiece(space, n);
|
||||||
|
@ -427,7 +427,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
|||||||
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
|
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
|
||||||
|
|
||||||
// Start parallel Executors.
|
// Start parallel Executors.
|
||||||
const int num_executors = executors_and_keys->items.size();
|
const size_t num_executors = executors_and_keys->items.size();
|
||||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||||
num_executors, run_state.rendez, [&run_state](const Status& ret) {
|
num_executors, run_state.rendez, [&run_state](const Status& ret) {
|
||||||
{
|
{
|
||||||
@ -458,7 +458,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
|||||||
options_.config.graph_options().build_cost_model();
|
options_.config.graph_options().build_cost_model();
|
||||||
const int64 build_cost_model_after =
|
const int64 build_cost_model_after =
|
||||||
options_.config.graph_options().build_cost_model_after();
|
options_.config.graph_options().build_cost_model_after();
|
||||||
int measure_step_count = executor_step_count - build_cost_model_after;
|
int64 measure_step_count = executor_step_count - build_cost_model_after;
|
||||||
if (measure_step_count >= 0) {
|
if (measure_step_count >= 0) {
|
||||||
update_cost_model =
|
update_cost_model =
|
||||||
((measure_step_count + 1) % build_cost_model_every == 0);
|
((measure_step_count + 1) % build_cost_model_every == 0);
|
||||||
@ -611,7 +611,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start parallel Executors.
|
// Start parallel Executors.
|
||||||
const int num_executors = executors_and_keys->items.size();
|
const size_t num_executors = executors_and_keys->items.size();
|
||||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||||
num_executors, run_state->rendez, [run_state](const Status& ret) {
|
num_executors, run_state->rendez, [run_state](const Status& ret) {
|
||||||
if (!ret.ok()) {
|
if (!ret.ok()) {
|
||||||
|
@ -232,7 +232,7 @@ struct NodeItem {
|
|||||||
int input_start = 0;
|
int input_start = 0;
|
||||||
|
|
||||||
// Number of output edges.
|
// Number of output edges.
|
||||||
int num_output_edges;
|
size_t num_output_edges;
|
||||||
|
|
||||||
PendingCounts::Handle pending_id;
|
PendingCounts::Handle pending_id;
|
||||||
|
|
||||||
@ -307,7 +307,7 @@ class GraphView {
|
|||||||
void Initialize(const Graph* g);
|
void Initialize(const Graph* g);
|
||||||
Status SetAllocAttrs(const Graph* g, const Device* device);
|
Status SetAllocAttrs(const Graph* g, const Device* device);
|
||||||
|
|
||||||
NodeItem* node(int id) const {
|
NodeItem* node(size_t id) const {
|
||||||
DCHECK_GE(id, 0);
|
DCHECK_GE(id, 0);
|
||||||
DCHECK_LT(id, num_nodes_);
|
DCHECK_LT(id, num_nodes_);
|
||||||
uint32 offset = node_offsets_[id];
|
uint32 offset = node_offsets_[id];
|
||||||
@ -454,7 +454,7 @@ GraphView::~GraphView() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t GraphView::NodeItemBytes(const Node* n) {
|
size_t GraphView::NodeItemBytes(const Node* n) {
|
||||||
const int num_output_edges = n->out_edges().size();
|
const size_t num_output_edges = n->out_edges().size();
|
||||||
const int num_inputs = n->num_inputs();
|
const int num_inputs = n->num_inputs();
|
||||||
const int num_outputs = n->num_outputs();
|
const int num_outputs = n->num_outputs();
|
||||||
|
|
||||||
@ -500,11 +500,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
|||||||
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
|
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
|
||||||
// values as "int" vs "size_t" in CHECK_LE.
|
// values as "int" vs "size_t" in CHECK_LE.
|
||||||
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
|
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
|
||||||
const uint32 offset = ptr - space_;
|
const uint32 offset = static_cast<uint32>(ptr - space_);
|
||||||
node_offsets_[id] = offset;
|
node_offsets_[id] = offset;
|
||||||
ptr += bytes;
|
ptr += bytes;
|
||||||
|
|
||||||
const int num_output_edges = n->out_edges().size();
|
const size_t num_output_edges = n->out_edges().size();
|
||||||
const int num_inputs = n->num_inputs();
|
const int num_inputs = n->num_inputs();
|
||||||
const int num_outputs = n->num_outputs();
|
const int num_outputs = n->num_outputs();
|
||||||
|
|
||||||
@ -580,9 +580,10 @@ void GraphView::Initialize(const Graph* g) {
|
|||||||
CHECK_EQ(ptr, space_ + total_bytes);
|
CHECK_EQ(ptr, space_ + total_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetMaxPendingCounts(const Node* n, int* max_pending, int* max_dead_count) {
|
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
|
||||||
const int num_in_edges = n->in_edges().size();
|
size_t* max_dead_count) {
|
||||||
int initial_count;
|
const size_t num_in_edges = n->in_edges().size();
|
||||||
|
size_t initial_count;
|
||||||
if (IsMerge(n)) {
|
if (IsMerge(n)) {
|
||||||
// merge waits all control inputs so we initialize the pending
|
// merge waits all control inputs so we initialize the pending
|
||||||
// count to be the number of control edges.
|
// count to be the number of control edges.
|
||||||
@ -626,8 +627,7 @@ Status ExecutorImpl::Initialize() {
|
|||||||
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
||||||
|
|
||||||
// See if this node is a root node, and if so, add to root_nodes_.
|
// See if this node is a root node, and if so, add to root_nodes_.
|
||||||
const int num_in_edges = n->in_edges().size();
|
if (n->in_edges().empty()) {
|
||||||
if (num_in_edges == 0) {
|
|
||||||
root_nodes_.push_back(n);
|
root_nodes_.push_back(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -659,7 +659,7 @@ Status ExecutorImpl::Initialize() {
|
|||||||
// pending counts data structure, and allocate a handle in
|
// pending counts data structure, and allocate a handle in
|
||||||
// that frame's pending counts data structure that has enough
|
// that frame's pending counts data structure that has enough
|
||||||
// space to store these maximal count values.
|
// space to store these maximal count values.
|
||||||
int max_pending, max_dead;
|
size_t max_pending, max_dead;
|
||||||
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
||||||
item->pending_id =
|
item->pending_id =
|
||||||
frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
|
frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
|
||||||
@ -896,7 +896,7 @@ class ExecutorState {
|
|||||||
Entry* input_tensors;
|
Entry* input_tensors;
|
||||||
|
|
||||||
// The number of outstanding ops for each iteration.
|
// The number of outstanding ops for each iteration.
|
||||||
int outstanding_ops;
|
size_t outstanding_ops;
|
||||||
|
|
||||||
// The number of outstanding frames for each iteration.
|
// The number of outstanding frames for each iteration.
|
||||||
int outstanding_frame_count;
|
int outstanding_frame_count;
|
||||||
@ -1037,13 +1037,13 @@ class ExecutorState {
|
|||||||
|
|
||||||
inline IterationState* GetIteration(int64 iter)
|
inline IterationState* GetIteration(int64 iter)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||||
int index = iter % iterations.size();
|
size_t index = iter % iterations.size();
|
||||||
return iterations[index];
|
return iterations[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void SetIteration(int64 iter, IterationState* state)
|
inline void SetIteration(int64 iter, IterationState* state)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||||
int index = iter % iterations.size();
|
size_t index = iter % iterations.size();
|
||||||
DCHECK(state == nullptr || iterations[index] == nullptr);
|
DCHECK(state == nullptr || iterations[index] == nullptr);
|
||||||
iterations[index] = state;
|
iterations[index] = state;
|
||||||
}
|
}
|
||||||
@ -1404,7 +1404,7 @@ void ExecutorImpl::InitializePending(const Graph* graph,
|
|||||||
for (const Node* n : graph->nodes()) {
|
for (const Node* n : graph->nodes()) {
|
||||||
const int id = n->id();
|
const int id = n->id();
|
||||||
const string& name = cf_info.frame_names[id];
|
const string& name = cf_info.frame_names[id];
|
||||||
int max_pending, max_dead;
|
size_t max_pending, max_dead;
|
||||||
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
||||||
const NodeItem* item = gview_.node(id);
|
const NodeItem* item = gview_.node(id);
|
||||||
PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
|
PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
|
||||||
@ -2027,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool completed = false;
|
bool completed = false;
|
||||||
int ready_size = ready.size();
|
size_t ready_size = ready.size();
|
||||||
if (ready_size == 0 || !s.ok()) {
|
if (ready_size == 0 || !s.ok()) {
|
||||||
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
|
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
|
||||||
} else if (ready_size > 1) {
|
} else if (ready_size > 1) {
|
||||||
@ -2375,10 +2375,10 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
|
|||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
const GraphView& gview = executor->gview_;
|
const GraphView& gview = executor->gview_;
|
||||||
IterationState* iter_state = GetIteration(iter);
|
IterationState* iter_state = GetIteration(iter);
|
||||||
const int num_output_edges = item->num_output_edges;
|
const size_t num_output_edges = item->num_output_edges;
|
||||||
const EdgeInfo* edges = item->output_edge_list();
|
const EdgeInfo* edges = item->output_edge_list();
|
||||||
Entry* input_tensors = iter_state->input_tensors;
|
Entry* input_tensors = iter_state->input_tensors;
|
||||||
for (int out_index = 0; out_index < num_output_edges; out_index++) {
|
for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
|
||||||
const EdgeInfo& e = edges[out_index];
|
const EdgeInfo& e = edges[out_index];
|
||||||
const int dst_id = e.dst_id;
|
const int dst_id = e.dst_id;
|
||||||
const NodeItem* dst_item = gview.node(dst_id);
|
const NodeItem* dst_item = gview.node(dst_id);
|
||||||
|
@ -162,7 +162,7 @@ class ExecutorBarrier {
|
|||||||
//
|
//
|
||||||
// 'done' is called after the last executor completes, and
|
// 'done' is called after the last executor completes, and
|
||||||
// ExecutorBarrier is deleted.
|
// ExecutorBarrier is deleted.
|
||||||
ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
|
ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
|
||||||
: rendez_(r), done_cb_(done), pending_(num) {}
|
: rendez_(r), done_cb_(done), pending_(num) {}
|
||||||
|
|
||||||
~ExecutorBarrier() {}
|
~ExecutorBarrier() {}
|
||||||
|
@ -274,8 +274,9 @@ class CallOp : public AsyncOpKernel {
|
|||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
ctx->SetStatus(status);
|
ctx->SetStatus(status);
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(rets->size(), ctx->num_outputs());
|
const int ret_size = static_cast<int>(rets->size());
|
||||||
for (size_t i = 0; i < rets->size(); ++i) {
|
CHECK_EQ(ret_size, ctx->num_outputs());
|
||||||
|
for (int i = 0; i < ret_size; ++i) {
|
||||||
ctx->set_output(i, (*rets)[i]);
|
ctx->set_output(i, (*rets)[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1000,7 +1001,7 @@ string NewName(const Node* n, bool pretty) {
|
|||||||
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
|
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
|
||||||
// We visit nodes in forward topological sort order, which is a
|
// We visit nodes in forward topological sort order, which is a
|
||||||
// possible execution order of the graph.
|
// possible execution order of the graph.
|
||||||
std::vector<int> pending(g->num_node_ids());
|
std::vector<size_t> pending(g->num_node_ids());
|
||||||
std::deque<const Node*> ready;
|
std::deque<const Node*> ready;
|
||||||
for (const Node* n : g->nodes()) {
|
for (const Node* n : g->nodes()) {
|
||||||
pending[n->id()] = n->in_edges().size();
|
pending[n->id()] = n->in_edges().size();
|
||||||
@ -1154,7 +1155,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
|||||||
|
|
||||||
Graph* g = gbody_->graph;
|
Graph* g = gbody_->graph;
|
||||||
|
|
||||||
const int num_y = gbody_->ret_nodes.size();
|
const int num_y = static_cast<int>(gbody_->ret_nodes.size());
|
||||||
|
|
||||||
// Populate 'y_node_outputs_' with node function body outputs.
|
// Populate 'y_node_outputs_' with node function body outputs.
|
||||||
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
|
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
|
||||||
@ -1169,7 +1170,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
|||||||
y_node_outputs.push_back({y, 0});
|
y_node_outputs.push_back({y, 0});
|
||||||
DCHECK_EQ(y->type_string(), kRetOp);
|
DCHECK_EQ(y->type_string(), kRetOp);
|
||||||
const DataType dtype = y->input_type(0);
|
const DataType dtype = y->input_type(0);
|
||||||
const int index = gbody_->arg_nodes.size();
|
const int index = static_cast<int>(gbody_->arg_nodes.size());
|
||||||
Node* dy = AddArg(g, dtype, index);
|
Node* dy = AddArg(g, dtype, index);
|
||||||
gbody_->arg_types.push_back(dtype);
|
gbody_->arg_types.push_back(dtype);
|
||||||
gbody_->arg_nodes.push_back(dy);
|
gbody_->arg_nodes.push_back(dy);
|
||||||
@ -1177,7 +1178,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
|
// Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
|
||||||
const int num_x = fbody_->arg_nodes.size();
|
const size_t num_x = fbody_->arg_nodes.size();
|
||||||
std::vector<NodeOut> x_node_outputs;
|
std::vector<NodeOut> x_node_outputs;
|
||||||
x_node_outputs.reserve(num_x);
|
x_node_outputs.reserve(num_x);
|
||||||
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
|
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
|
||||||
@ -1200,7 +1201,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
|||||||
gbody_->ret_nodes.clear();
|
gbody_->ret_nodes.clear();
|
||||||
// Add new return nodes to the function gradient body for each node
|
// Add new return nodes to the function gradient body for each node
|
||||||
// in 'x_grad_nodes'.
|
// in 'x_grad_nodes'.
|
||||||
for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
|
const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
|
||||||
|
for (int i = 0; i < arg_types_size; ++i) {
|
||||||
Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
|
Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
|
||||||
Node* ret = AddRet(g, grad, i);
|
Node* ret = AddRet(g, grad, i);
|
||||||
gbody_->ret_nodes.push_back(ret);
|
gbody_->ret_nodes.push_back(ret);
|
||||||
|
@ -82,7 +82,7 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
|
|||||||
// Determine a suitable stream to use.
|
// Determine a suitable stream to use.
|
||||||
int stream_id = highest_stream_id + 1;
|
int stream_id = highest_stream_id + 1;
|
||||||
for (const Edge* e : n->in_edges()) {
|
for (const Edge* e : n->in_edges()) {
|
||||||
const int fanout = e->src()->out_edges().size();
|
const size_t fanout = e->src()->out_edges().size();
|
||||||
if (fanout == 1) {
|
if (fanout == 1) {
|
||||||
stream_id = (*node_to_stream_id)[e->src()->id()];
|
stream_id = (*node_to_stream_id)[e->src()->id()];
|
||||||
break;
|
break;
|
||||||
|
@ -191,7 +191,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
|
|||||||
// example, process_state could maybe save the first stream executor
|
// example, process_state could maybe save the first stream executor
|
||||||
// it knows is valid.
|
// it knows is valid.
|
||||||
gpu::StreamExecutor* se = nullptr;
|
gpu::StreamExecutor* se = nullptr;
|
||||||
for (size_t i = 0; i < gpu_allocators_.size(); ++i) {
|
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
|
||||||
if (gpu_allocators_[i] != nullptr) {
|
if (gpu_allocators_[i] != nullptr) {
|
||||||
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
|
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
|
||||||
break;
|
break;
|
||||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
|
||||||
#include "tensorflow/core/common_runtime/executor.h"
|
#include "tensorflow/core/common_runtime/executor.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||||
@ -95,22 +94,24 @@ class SimpleRendezvous : public Rendezvous {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// static
|
GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {}
|
||||||
|
|
||||||
|
GraphRunner::~GraphRunner() {}
|
||||||
|
|
||||||
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||||
Env* env, const NamedTensorList& inputs,
|
const NamedTensorList& inputs,
|
||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
std::vector<Tensor>* outputs) {
|
std::vector<Tensor>* outputs) {
|
||||||
|
if (cpu_device_ == nullptr) {
|
||||||
|
return errors::NotFound("Cannot find a device for GraphRunner.");
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(vrv): Instead of copying the entire graph, consider modifying
|
// TODO(vrv): Instead of copying the entire graph, consider modifying
|
||||||
// the existing graph, and then removing those removed edges.
|
// the existing graph, and then removing those removed edges.
|
||||||
// prior to returning.
|
// prior to returning.
|
||||||
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
|
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
|
||||||
CopyGraph(*graph, graph_to_run.get());
|
CopyGraph(*graph, graph_to_run.get());
|
||||||
|
|
||||||
std::unique_ptr<Device> device = GetCPUDevice(env);
|
|
||||||
if (!device) {
|
|
||||||
return errors::NotFound("Cannot find a device for GraphRunner.");
|
|
||||||
}
|
|
||||||
|
|
||||||
SimpleRendezvous* rendez = new SimpleRendezvous;
|
SimpleRendezvous* rendez = new SimpleRendezvous;
|
||||||
core::ScopedUnref rendez_unref(rendez);
|
core::ScopedUnref rendez_unref(rendez);
|
||||||
|
|
||||||
@ -130,7 +131,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
|||||||
// Call RewriteGraphForExecution
|
// Call RewriteGraphForExecution
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
|
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
|
||||||
device->attributes()));
|
cpu_device_->attributes()));
|
||||||
|
|
||||||
// Create the local executor and the Rendezvous for fetching back the
|
// Create the local executor and the Rendezvous for fetching back the
|
||||||
// constants.
|
// constants.
|
||||||
@ -143,10 +144,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
|||||||
Graph* g = graph_to_run.release();
|
Graph* g = graph_to_run.release();
|
||||||
|
|
||||||
LocalExecutorParams params;
|
LocalExecutorParams params;
|
||||||
params.device = device.get();
|
// The ownership of the output tensors are bound to this device's lifetime.
|
||||||
|
params.device = cpu_device_.get();
|
||||||
params.function_library = function_library;
|
params.function_library = function_library;
|
||||||
params.create_kernel = [&device, g](const NodeDef& ndef, OpKernel** kernel) {
|
params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) {
|
||||||
return CreateNonCachedKernel(device.get(), nullptr, ndef,
|
return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef,
|
||||||
g->versions().producer(), kernel);
|
g->versions().producer(), kernel);
|
||||||
};
|
};
|
||||||
params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
|
params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -44,16 +45,26 @@ namespace tensorflow {
|
|||||||
// to be particularly lightweight, fast, or efficient.
|
// to be particularly lightweight, fast, or efficient.
|
||||||
class GraphRunner {
|
class GraphRunner {
|
||||||
public:
|
public:
|
||||||
|
// REQUIRES: `env` is not nullptr.
|
||||||
|
GraphRunner(Env* env);
|
||||||
|
~GraphRunner();
|
||||||
|
|
||||||
// Function semantics for `inputs`, `output_names` and `outputs`
|
// Function semantics for `inputs`, `output_names` and `outputs`
|
||||||
// matches those from Session::Run().
|
// matches those from Session::Run().
|
||||||
//
|
//
|
||||||
|
// NOTE: The output tensors share lifetime with the GraphRunner, and could
|
||||||
|
// be destroyed once the GraphRunner is destroyed.
|
||||||
|
//
|
||||||
// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
|
// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
|
||||||
// `function_library` may be nullptr.
|
// `function_library` may be nullptr.
|
||||||
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
||||||
static Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||||
Env* env, const NamedTensorList& inputs,
|
const NamedTensorList& inputs,
|
||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
std::vector<Tensor>* outputs);
|
std::vector<Tensor>* outputs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<Device> cpu_device_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -46,9 +46,9 @@ using test::internal::ExpectEqual;
|
|||||||
TEST(GraphRunnerTest, SingleConst) {
|
TEST(GraphRunnerTest, SingleConst) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::NewRootScope();
|
||||||
auto c = ops::Const(root, 42.0f);
|
auto c = ops::Const(root, 42.0f);
|
||||||
|
GraphRunner graph_runner(Env::Default());
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
|
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs);
|
||||||
{c.name()}, &outputs);
|
|
||||||
TF_ASSERT_OK(s);
|
TF_ASSERT_OK(s);
|
||||||
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
||||||
}
|
}
|
||||||
@ -57,9 +57,10 @@ TEST(GraphRunnerTest, MultiFetchConst) {
|
|||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::NewRootScope();
|
||||||
auto c = ops::Const(root, 42.0f);
|
auto c = ops::Const(root, 42.0f);
|
||||||
auto pi = ops::Const(root, 3.14f);
|
auto pi = ops::Const(root, 3.14f);
|
||||||
|
GraphRunner graph_runner(Env::Default());
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
|
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()},
|
||||||
{c.name(), pi.name()}, &outputs);
|
&outputs);
|
||||||
TF_ASSERT_OK(s);
|
TF_ASSERT_OK(s);
|
||||||
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
||||||
ExpectEqual(3.14f, outputs[1].scalar<float>()());
|
ExpectEqual(3.14f, outputs[1].scalar<float>()());
|
||||||
@ -78,9 +79,10 @@ TEST(GraphRunnerTest, FeedAndFetch) {
|
|||||||
std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
|
std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
|
||||||
{"p2:0", p2_data}};
|
{"p2:0", p2_data}};
|
||||||
|
|
||||||
|
GraphRunner graph_runner(Env::Default());
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), inputs,
|
Status s =
|
||||||
{"add:0"}, &outputs);
|
graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
|
||||||
TF_ASSERT_OK(s);
|
TF_ASSERT_OK(s);
|
||||||
ExpectEqual(3.0f, outputs[0].scalar<float>()());
|
ExpectEqual(3.0f, outputs[0].scalar<float>()());
|
||||||
}
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user