diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 6a70c0e4057..e437987112b 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -202,6 +202,7 @@ filegroup(
         "//tensorflow/contrib/boosted_trees:all_files",
         "//tensorflow/contrib/boosted_trees/lib:all_files",
         "//tensorflow/contrib/boosted_trees/proto:all_files",
+        "//tensorflow/contrib/boosted_trees/resources:all_files",
         "//tensorflow/contrib/cloud:all_files",
         "//tensorflow/contrib/cloud/kernels:all_files",
         "//tensorflow/contrib/compiler:all_files",
@@ -256,6 +257,7 @@ filegroup(
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
+        "//tensorflow/contrib/xla_tf_graph:all_files",
         "//tensorflow/core:all_files",
         "//tensorflow/core/debug:all_files",
         "//tensorflow/core/distributed_runtime:all_files",
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index ecb071a416c..59d13e53934 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -51,6 +51,7 @@ genrule(
         "test_graph_tfgather.pb",
         "test_graph_tfmatmul.pb",
         "test_graph_tfmatmulandadd.pb",
+        "test_graph_tffunction.pb",
     ],
     cmd = "$(location :make_test_graphs) --out_dir $(@D)",
     tags = ["manual"],
@@ -114,6 +115,15 @@ tf_library(
     tags = ["manual"],
 )
 
+tf_library(
+    name = "test_graph_tffunction",
+    testonly = 1,
+    config = "test_graph_tffunction.config.pbtxt",
+    cpp_class = "FunctionComp",
+    graph = "test_graph_tffunction.pb",
+    tags = ["manual"],
+)
+
 cc_test(
     name = "tfcompile_test",
     srcs = ["tfcompile_test.cc"],
@@ -122,6 +132,7 @@ cc_test(
         ":test_graph_tfadd",
         ":test_graph_tfadd_with_ckpt",
         ":test_graph_tfadd_with_ckpt_saver",
+        ":test_graph_tffunction",
         ":test_graph_tfgather",
         ":test_graph_tfmatmul",
         ":test_graph_tfmatmulandadd",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9279c45f373..98c13958d37 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -95,6 +96,17 @@ def tfmatmulandadd(_):
   math_ops.add(x, y, name='x_y_sum')
 
 
+def tffunction(_):
+
+  @function.Defun(dtypes.int32, dtypes.int32)
+  def test_func(a, b):
+    return a + b
+
+  x = constant_op.constant([1], name='x_const')
+  y = constant_op.constant([2], name='y_const')
+  test_func(x, y, name='func_call')  # pylint: disable=unexpected-keyword-arg
+
+
 def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
@@ -112,6 +124,7 @@ def main(_):
   write_graph(tfgather, FLAGS.out_dir)
   write_graph(tfmatmul, FLAGS.out_dir)
   write_graph(tfmatmulandadd, FLAGS.out_dir)
+  write_graph(tffunction, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
@@ -121,7 +134,6 @@ if __name__ == '__main__':
       '--out_dir',
       type=str,
       default='',
-      help='Output directory for graphs, checkpoints and savers.'
-  )
+      help='Output directory for graphs, checkpoints and savers.')
   FLAGS, unparsed = parser.parse_known_args()
   app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
new file mode 100644
index 00000000000..eb9c1cacb7f
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+  id { node_name: "x_const" }
+  shape {
+    dim { size: 1 }
+  }
+}
+feed {
+  id { node_name: "y_const" }
+  shape {
+    dim { size: 1 }
+  }
+}
+fetch {
+  id { node_name: "func_call" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index f57d2859dfa..76343b97521 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
@@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
   }
 }
 
+TEST(TFCompileTest, Function) {
+  // The function is equivalent to an addition.
+  FunctionComp add_fn;
+  EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
+  EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
+
+  add_fn.arg0() = 1;
+  add_fn.arg1() = 2;
+  EXPECT_TRUE(add_fn.Run());
+  EXPECT_EQ(add_fn.error_msg(), "");
+  EXPECT_EQ(add_fn.result0(), 3);
+  EXPECT_EQ(add_fn.result0_data()[0], 3);
+  EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
+}
+
 }  // namespace
 }  // namespace tfcompile
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 22dbf7ec99f..b27c07d0d98 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
 }
 
 // Make sure we don't recurse infinitely on recursive functions.
-const int kMaxRecursionDepth = 5;
+const int kMaxRecursionDepth = 10;
 
 bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
                       int depth, FunctionLibraryRuntime* lib_runtime);
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index c3e8ff724c1..a0cd905f173 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
   });
 }
 
+TEST_F(OpTest, OnesLike) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
+  });
+}
+
 }  // anonymous namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index c96826fd0a6..1e85d3a2c8b 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -257,6 +257,11 @@ class UnaryOpsTest(XLATestCase):
         np.array([[4, 3], [2, 1]], dtype=dtype),
         expected=np.array([[0, 0], [0, 0]], dtype=dtype))
 
+    self._assertOpOutputMatchesExpected(
+        array_ops.ones_like,
+        np.array([[4, 3], [2, 1]], dtype=dtype),
+        expected=np.array([[1, 1], [1, 1]], dtype=dtype))
+
   def testLogicalOps(self):
     self._assertOpOutputMatchesExpected(
         math_ops.logical_not,
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 74e3297dc33..24a99f253d6 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
 
+class OnesLikeOp : public XlaOpKernel {
+ public:
+  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape input_shape = ctx->InputShape(0);
+
+    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
+    ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
+  }
+};
+
+REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index ba975d617dc..33b4a43aa15 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,
 
 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
     : options_(std::move(options)),
+      initialization_status_(Status::OK()),
       next_step_id_(1),
       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
-      device_mgr_({device_}) {}
+      device_mgr_({device_}) {
+  if (options_.populate_resource_manager) {
+    initialization_status_ =
+        (*options_.populate_resource_manager)(device_->resource_manager());
+  }
+}
 
 XlaCompiler::~XlaCompiler() = default;
 
@@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
                                  CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
 
+  // Report the error here if initialization failed.
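+  // (The status was recorded by the constructor, which runs the optional
+  // populate_resource_manager callback; surfacing the failure here keeps the
+  // constructor itself infallible.)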
+  TF_RETURN_IF_ERROR(initialization_status_);
+
   xla::ComputationBuilder builder(client(), name);
   XlaContext* context =
       new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 3ed920521b2..3d28ca37460 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -214,6 +214,12 @@ class XlaCompiler {
     // This is useful to prune stateful operators that should not be executed
     // from a function body.
     bool prune_unreachable_nodes = false;
+
+    // If not nullptr, populate_resource_manager is called with the
+    // compilation device's resource manager when the compilation
+    // device is created, and can be used to create metadata objects
+    // that can be accessed by XLA op kernels.
+    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
   };
 
   explicit XlaCompiler(Options options);
@@ -247,6 +253,7 @@ class XlaCompiler {
   Status BuildExecutable(const CompilationResult& result,
                          std::unique_ptr<xla::LocalExecutable>* executable);
 
+  const Options& options() const { return options_; }
   xla::Client* client() const { return options_.client; }
   XlaCompilationDevice* device() const { return device_; }
   const DeviceMgr* device_mgr() const { return &device_mgr_; }
@@ -260,6 +267,9 @@ class XlaCompiler {
  private:
   Options options_;
 
+  // Status set to non-OK in the constructor if initialization fails.
+  Status initialization_status_;
+
   // Returns the next step sequence number.
   int64 NextStepId();
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index aa809f85a15..1cc7f4abd15 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -17,12 +17,14 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -33,6 +35,65 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// Helper class to test the ability to pass resources through to XLA
+// compiled kernels.
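+// The resource is registered with the compilation device's ResourceMgr via
+// the populate_resource_manager option and is mutated from inside an op's
+// Compile() method.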
+class DummyResourceForTest : public ResourceBase {
+ public:
+  string DebugString() override { return "dummy"; }
+  void Increment() { ++value_; }
+  int Get() { return value_; }
+
+ private:
+  int value_ = 0;
+};
+
+class DummyReadResourceOp : public XlaOpKernel {
+ public:
+  explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
+    OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
+    DummyResourceForTest* dummy;
+    OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
+                            rm->default_container(), "dummy", &dummy));
+    dummy->Increment();
+    dummy->Unref();
+
+    ctx->SetOutput(0, ctx->Input(0));
+  }
+};
+
+class DummyReadResourceCC {
+ public:
+  DummyReadResourceCC(const Scope& scope, const Input& value) {
+    if (!scope.ok()) return;
+    auto _value = ops::AsNodeOut(scope, value);
+    if (!scope.ok()) return;
+    Node* ret;
+    const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
+    auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
+    scope.UpdateBuilder(&builder);
+    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+    if (!scope.ok()) return;
+    this->output_ = Output(ret, 0);
+  }
+  Node* node() const { return output_.node(); }
+
+  Output output_;
+};
+
+REGISTER_OP("DummyReadResource")
+    .Input("input: int32")
+    .Output("output: int32")
+    .Doc(R"doc(
+A dummy Op.
+
+input: dummy input.
+output: dummy output.
+)doc");
+
+REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
+
 class XlaCompilerTest : public ::testing::Test {
  protected:
   void SetUp() override {
@@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   }
 }
 
+// Tests that a resource installed in the compilation device's resource
+// manager via the populate_resource_manager option is visible to XLA op
+// kernels while a graph is being compiled.
+TEST_F(XlaCompilerTest, ResourceManager) {
+  // Builds a graph that calls the dummy resource Op.
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
+  auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the argument.
+  std::vector<XlaCompiler::Argument> args(1);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2});
+
+  DummyResourceForTest* resource = new DummyResourceForTest();
+
+  // Compiles the graph.
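+  // (Note the extra Ref() in the callback below: the resource manager takes
+  // ownership of one reference via Create(), so the test keeps its own
+  // reference to read the counter and must Unref() it at the end.)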
+  auto options = DefaultOptions();
+  std::function<Status(ResourceMgr*)> populate_function =
+      [resource](ResourceMgr* rm) {
+        resource->Ref();
+        return rm->Create(rm->default_container(), "dummy", resource);
+      };
+  options.populate_resource_manager = &populate_function;
+  XlaCompiler compiler(options);
+  auto flr = BuildFunctionLibraryRuntime(compiler);
+
+  EXPECT_EQ(0, resource->Get());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
+                                     &result));
+
+  EXPECT_EQ(1, resource->Get());
+
+  resource->Unref();
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index dc5a342bcdd..53dcdec7a25 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
   XlaContext::Get(context_).AddSideEffects();
 }
 
+const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
+  return XlaContext::Get(context_).compiler()->options();
+}
+
 void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
 void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
   context_->CtxFailureWithWarning(s);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index d214879e3cc..60e3b59d32a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/platform/macros.h"
@@ -182,6 +183,11 @@ class XlaOpKernelContext {
   // Returns the underlying OpKernelContext. Use rarely.
   OpKernelContext* op_kernel_context() const { return context_; }
 
+  // Returns the options passed to the XlaCompiler that is being
+  // run. Used for, e.g., While to inherit options needed for nested
+  // computation.
+  const XlaCompiler::Options& GetCompilerOptions() const;
+
   // TODO(phawkins): find a better home for these helpers.
 
   // Get an XLA lambda to compute Max. This is cached in the
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 5b895bfdf60..13fdfc3b0c8 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -167,6 +167,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
           !backend.second.op_filter(kdef.get())) {
         continue;
       }
+      VLOG(2) << "XLA op registration: device: " << backend.first
+              << " op: " << op.first;
       registry.kernel_registrars_.emplace_back(
           new kernel_factory::OpKernelRegistrar(
               new KernelDef(*kdef), "XlaJitOp", op.second->factory));
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e73a29ddee1..35c0efb8f0f 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -6,6 +6,7 @@ package_group(
     name = "friends",
     packages = [
         "//tensorflow/compiler/...",
+        "//tensorflow/contrib/xla_tf_graph/...",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 88efd87d1cc..22a70681468 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1229,8 +1229,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
   VLOG(2) << "done with request";
 
   if (!s.ok()) {
-    NoteError(s);
-    return first_error_;
+    return s;
   }
 
   return response.is_constant();
 }
@@ -1255,8 +1254,7 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
   VLOG(2) << "done with request";
 
   if (!s.ok()) {
-    NoteError(s);
-    return first_error_;
+    return s;
   }
 
   TF_RET_CHECK(response.output().handle() != 0);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index ef3cba6fa08..dddcf519749 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -120,6 +120,7 @@ class HloComputation {
   }
 
   const string& name() const { return name_; }
+  void set_name(const string& name) { name_ = name; }
 
   // Return a string representation of the computation.
   string ToString() const;
@@ -257,7 +258,7 @@ class HloComputation {
   // Internal helper to collect unreachable roots.
   std::vector<HloInstruction*> CollectUnreachableRoots() const;
 
-  const string name_;
+  string name_;
   HloInstruction* root_instruction_;
 
   // Module containing this computation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 8fe1897e75c..5fc7d6b22e9 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -357,7 +357,9 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
 Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
   // Compute the cost of the fused expression.
   HloInstruction* fused_expression_root = fusion->fused_expression_root();
-  HloCostAnalysis visitor(shape_size_);
+  // Don't compute sizes inside of fused ops. We don't use the size here and
+  // the operations inside might not have a layout.
+  HloCostAnalysis visitor([](const Shape&) { return 0; });
   TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
 
   // Attribute the cost of the fused expression to the fusion node.
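For context, the cost analysis is driven as an HLO visitor: construct it with a
shape-size function, run it from a root instruction via Accept(), then read the
accumulated counts. A minimal sketch of that calling pattern, mirroring the
tests below (the shape-size lambda here is a stand-in for whatever byte-size
function a caller supplies; inside fused expressions the change above
substitutes one that always returns 0, since fused instructions may not carry
layouts):

    // Sketch only: any HloInstruction* may serve as the root.
    HloCostAnalysis analysis(
        [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); });
    TF_RETURN_IF_ERROR(root_instruction->Accept(&analysis));
    const auto flops = analysis.flop_count();
    const auto transcendentals = analysis.transcendental_count();
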
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 9f1c91d41c6..22e782da27d 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -375,6 +375,33 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
   EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
 }
 
+TEST_F(FusionCostAnalysis, NoLayout) {
+  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
+  // Instructions within a fused op may have no layout.
+  Shape shape_without_layout = shape_with_layout;
+  shape_without_layout.clear_layout();
+
+  auto c1 = HloInstruction::CreateConstant(
+      LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
+  auto c2 =
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
+
+  auto broadcast =
+      HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
+  auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
+                                          c1.get(), broadcast.get());
+
+  auto fusion = HloInstruction::CreateFusion(
+      shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
+  fusion->FuseInstruction(broadcast.get());
+
+  HloCostAnalysis fusion_analysis(ShapeSize);
+  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
+
+  EXPECT_EQ(fusion_analysis.flop_count(), 120);
+  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
+}
+
 TEST_F(HloCostAnalysisTest, TupleCost) {
   HloCostAnalysis analysis(ShapeSize);
   {
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 36064e93fe8..8ed672aa9b8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -31,20 +31,38 @@ limitations under the License.
 
 namespace xla {
 
-HloComputation* HloModule::AddEntryComputation(
+HloModule::HloModule(const string& name,
+                     const VersionedComputationHandle& entry_computation_handle)
+    : name_(name),
+      entry_computation_(nullptr),
+      has_entry_computation_handle_(true),
+      entry_computation_handle_(entry_computation_handle),
+      computation_name_uniquer_(/*separator=*/".") {}
+
+HloModule::HloModule(const string& name)
+    : name_(name),
+      entry_computation_(nullptr),
+      computation_name_uniquer_(/*separator=*/".") {}
+
+HloComputation* HloModule::AddComputationInternal(
     std::unique_ptr<HloComputation> computation) {
-  CHECK_EQ(nullptr, entry_computation_);
-  entry_computation_ = computation.get();
+  computation->set_name(
+      computation_name_uniquer_.GetUniqueName(computation->name()));
   computation->set_parent(this);
   computations_.push_back(std::move(computation));
   return computations_.back().get();
 }
 
+HloComputation* HloModule::AddEntryComputation(
+    std::unique_ptr<HloComputation> computation) {
+  CHECK_EQ(nullptr, entry_computation_);
+  entry_computation_ = computation.get();
+  return AddComputationInternal(std::move(computation));
+}
+
 HloComputation* HloModule::AddEmbeddedComputation(
     std::unique_ptr<HloComputation> computation) {
-  computation->set_parent(this);
-  computations_.push_back(std::move(computation));
-  return computations_.back().get();
+  return AddComputationInternal(std::move(computation));
 }
 
 void HloModule::ReplaceComputations(
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index d598750da65..1ff5c5dacb8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -41,19 +42,14 @@ namespace xla { // computations are owned by the module. class HloModule { public: - explicit HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle) - : name_(name), - entry_computation_(nullptr), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle); // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. - explicit HloModule(const string& name) - : name_(name), entry_computation_(nullptr) {} + explicit HloModule(const string& name); // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -111,6 +107,9 @@ class HloModule { uint64 RandomNew64() const; private: + HloComputation* AddComputationInternal( + std::unique_ptr computation); + const string name_; HloComputation* entry_computation_; std::vector> computations_; @@ -125,6 +124,9 @@ class HloModule { // Versioned handle of the entry computation of the module. bool has_entry_computation_handle_ = false; VersionedComputationHandle entry_computation_handle_; + + // Unique name generator for computation names, which are unique per module. + NameUniquer computation_name_uniquer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 0f4252522d3..dba9731e2a0 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -74,6 +74,11 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { EXPECT_MATCH( testing::ListToVec(module->MakeComputationPostOrder()), testing::UnorderedMatcher(computation1, computation2)); + + // We specified the same name for both computations, but the HloModule should + // have made the names unique. 
+  EXPECT_EQ(computation1->name(), "Constant");
+  EXPECT_EQ(computation2->name(), "Constant.1");
 }
 
 TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c05cf8c37d8..9472086e2b4 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -633,26 +633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
   switch (operation) {
     case TRIOP_CLAMP:
-      TF_RETURN_IF_ERROR(
-          ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
-      TF_RETURN_IF_ERROR(
-          ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
-      TF_RETURN_IF_ERROR(
-          ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
-      if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
-           (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
-        return rhs;
-      }
-      if (ShapeUtil::Rank(rhs) == 0) {
-        if (ShapeUtil::Compatible(lhs, ehs)) {
-          return lhs;
-        }
-        return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
-      }
-      return Unimplemented("not yet implemented: %s, %s %s",
-                           lhs.ShortDebugString().c_str(),
-                           ehs.ShortDebugString().c_str(),
-                           rhs.ShortDebugString().c_str());
+      return InferClampShape(lhs, rhs, ehs);
     case TRIOP_SELECT:
       return InferSelectShape(lhs, rhs, ehs);
     case TRIOP_UPDATE:
@@ -1332,6 +1313,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   return ShapeUtil::PermuteDimensions(InversePermutation(dimensions),
                                       operand);
 }
 
+// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
+// "degenerate" cases, as with binary elementwise ops.
+/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
+    const Shape& min, const Shape& operand, const Shape& max) {
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+  if (!ShapeUtil::SameElementType(min, operand) ||
+      !ShapeUtil::SameElementType(max, operand)) {
+    return InvalidArgument("clamp op with different operand types: %s, %s, %s",
+                           ShapeUtil::HumanString(min).c_str(),
+                           ShapeUtil::HumanString(operand).c_str(),
+                           ShapeUtil::HumanString(max).c_str());
+  }
+  if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
+       (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
+    return operand;
+  }
+  if (ShapeUtil::IsScalar(operand)) {
+    if (ShapeUtil::Compatible(min, max)) {
+      return min;
+    } else if (ShapeUtil::IsScalar(min)) {
+      return max;
+    } else if (ShapeUtil::IsScalar(max)) {
+      return min;
+    }
+  }
+  return Unimplemented(
+      "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(),
+      max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
+}
+
+// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
+// "degenerate" cases, as with binary elementwise ops, as well as scalar
+// broadcast from all operands, not just the predicate.
 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
     const Shape& pred, const Shape& on_true, const Shape& on_false) {
   if (!ShapeUtil::Compatible(on_true, on_false)) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index ced2f4d0017..c2223423e92 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -190,6 +190,10 @@ class ShapeInference {
       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
 
+  // Helper for inferring the shape of Clamp ops.
+  static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
+                                         const Shape& max);
+
   // Helper for inferring the shape of Select ops.
   static StatusOr<Shape> InferSelectShape(const Shape& pred,
                                           const Shape& on_true,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 5a1ae6b0024..6f968ded568 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
       testing::ContainsRegex("pred operand must have PRED element type"));
 }
 
+TEST_F(ShapeInferenceTest, ClampAllMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
+      matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampAllScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampBadShapes) {
+  // Type mismatch
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
+                   .ok());
+  // Dimension mismatch
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_64_, vector_32_, vector_32_)
+          .ok());
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_32_, vector_64_, vector_32_)
+          .ok());
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_32_, vector_32_, vector_64_)
+          .ok());
+  // Dimension mismatch, where one operand is a scalar
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+                   .ok());
+}
+
 TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
   StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
       VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 134eb91a1fe..0005c0c9e23 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -245,37 +246,69 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); - builder.Div(builder.ConstantR0(-5), builder.ConstantR0(2)); +struct DivS32Params { + int32 dividend; + int32 divisor; + int32 quotient; + int32 remainder; +}; - ComputeAndCompareR0(&builder, -2, {}); +void PrintTo(const DivS32Params& p, std::ostream* os) { + *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", " + << p.remainder << "}"; } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); +class DivS32Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; - ComputeAndCompareR0(&builder, -1, {}); +XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); + + ComputeAndCompareR0(&builder, p.quotient, {}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) { +XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { + DivS32Params p = GetParam(); ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(7919)); + builder.Rem(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); - ComputeAndCompareR0(&builder, -1309, {}); + ComputeAndCompareR0(&builder, p.remainder, {}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(INT_MAX)); +INSTANTIATE_TEST_CASE_P( + DivS32Test_Instantiation, DivS32Test, + ::testing::Values( + // Positive divisors. + DivS32Params{5, 2, 2, 1}, // + DivS32Params{-5, 2, -2, -1}, // + DivS32Params{17, 3, 5, 2}, // + DivS32Params{-17, 3, -5, -2}, // + // Negative divisors. + DivS32Params{5, -2, -2, 1}, // + DivS32Params{-5, -2, 2, -1}, // + DivS32Params{17, -3, -5, 2}, // + DivS32Params{-17, -3, 5, -2}, // + // Large positive divisors. + DivS32Params{INT32_MIN, 7919, -271181, -1309}, // + DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, // + DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, // + DivS32Params{INT32_MIN, 0x40000000, -2, 0}, // + DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, // + // Large negative divisors. 
+        DivS32Params{INT32_MIN, INT32_MIN, 1, 0},                   //
+        DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1},              //
+        DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1},   //
+        DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX},           //
+        DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0},              //
+        DivS32Params{INT32_MIN, -0x40000000, 2, 0},                 //
+        DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
 
-  ComputeAndCompareR0<int32>(&builder, -1, {});
-}
-
-TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
   builder.Rem(x, builder.ConstantR0<int32>(80000));
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 595d8997388..a726471d0fb 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -7,8 +7,6 @@ exports_files(["LICENSE"])
 
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
-load("//tensorflow:tensorflow.bzl", "if_not_windows")
-
 py_library(
     name = "contrib_py",
     srcs = glob(["**/*.py"]),
@@ -46,6 +44,7 @@ py_library(
         "//tensorflow/contrib/losses:losses_py",
        "//tensorflow/contrib/memory_stats:memory_stats_py",
        "//tensorflow/contrib/metrics:metrics_py",
+        "//tensorflow/contrib/nccl:nccl_py",
        "//tensorflow/contrib/ndlstm",
        "//tensorflow/contrib/nn:nn_py",
        "//tensorflow/contrib/opt:opt_py",
@@ -65,9 +64,7 @@ py_library(
        "//tensorflow/contrib/tfprof",
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/contrib/util:util_py",
-    ] + if_not_windows([
-        "//tensorflow/contrib/nccl:nccl_py",
-    ]),
+    ],
 )
 
 cc_library(
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index d4ddd1cf6a6..1ce1b4da090 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -35,6 +35,7 @@ from tensorflow.contrib import image
 from tensorflow.contrib import input_pipeline
 from tensorflow.contrib import integrate
 from tensorflow.contrib import keras
+from tensorflow.contrib import kernel_methods
 from tensorflow.contrib import labeled_tensor
 from tensorflow.contrib import layers
 from tensorflow.contrib import learn
@@ -45,6 +46,7 @@ from tensorflow.contrib import lookup
 from tensorflow.contrib import losses
 from tensorflow.contrib import memory_stats
 from tensorflow.contrib import metrics
+from tensorflow.contrib import nccl
 from tensorflow.contrib import nn
 from tensorflow.contrib import opt
 from tensorflow.contrib import quantization
diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD
index 714bd324c2a..011c02d720f 100644
--- a/tensorflow/contrib/boosted_trees/lib/BUILD
+++ b/tensorflow/contrib/boosted_trees/lib/BUILD
@@ -160,3 +160,90 @@ cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+cc_library(
+    name = "models",
+    srcs = ["models/multiple_additive_trees.cc"],
+    hdrs = ["models/multiple_additive_trees.h"],
+    deps = [
+        ":trees",
+        ":utils",
+        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
+        "//tensorflow/core:framework_headers_lib",
+    ],
+)
+
+cc_test(
+    name = "multiple_additive_trees_test",
+    size = "small",
+    srcs = ["models/multiple_additive_trees_test.cc"],
+    deps = [
+        ":batch_features_testutil",
+        ":models",
+        ":random_tree_gen",
+        "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:tensor_testutil",
+        "//tensorflow/core:test",
"//tensorflow/core:test_main", + ], +) + +cc_library( + name = "trees", + srcs = ["trees/decision_tree.cc"], + hdrs = ["trees/decision_tree.h"], + deps = [ + ":utils", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_test( + name = "trees_test", + size = "small", + srcs = ["trees/decision_tree_test.cc"], + deps = [ + ":trees", + ":utils", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "batch_features_testutil", + testonly = 1, + srcs = ["testutil/batch_features_testutil.cc"], + hdrs = ["testutil/batch_features_testutil.h"], + deps = [ + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "random_tree_gen", + srcs = ["testutil/random_tree_gen.cc"], + hdrs = ["testutil/random_tree_gen.h"], + deps = [ + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "random_tree_gen_main", + srcs = ["testutil/random_tree_gen_main.cc"], + deps = [ + ":random_tree_gen", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc new file mode 100644 index 00000000000..16bffd9becc --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -0,0 +1,140 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h" +#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" +#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" +#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" + +namespace tensorflow { +namespace boosted_trees { +namespace models { + +namespace { +void CalculateTreesToKeep( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const std::vector& trees_to_drop, const int32 num_trees, + const bool only_finalized, std::vector* trees_to_keep) { + trees_to_keep->reserve(num_trees - trees_to_drop.size()); + + int32 index = 0; + // This assumes that trees_to_drop is a sorted list of tree ids. 
+  for (int32 tree = 0; tree < num_trees; ++tree) {
+    if (index < trees_to_drop.size() && trees_to_drop[index] == tree) {
+      ++index;
+      continue;
+    }
+    if (only_finalized && config.tree_metadata_size() > 0 &&
+        !config.tree_metadata(tree).is_finalized()) {
+      continue;
+    }
+    trees_to_keep->push_back(tree);
+  }
+}
+
+void UpdatePredictions(
+    const int32 index_1, const int32 index_2, const float value,
+    tensorflow::TTypes<float>::Matrix* output_predictions,
+    tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
+  (*output_predictions)(index_1, index_2) += value;
+
+  if (additional_output_predictions != nullptr) {
+    (*additional_output_predictions)(index_1, index_2) += value;
+  }
+}
+
+void UpdatePredictionsBasedOnTree(
+    const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
+    const int32 tree_idx, const boosted_trees::utils::Example& example,
+    tensorflow::TTypes<float>::Matrix* output_predictions,
+    tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
+  const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx);
+  const float tree_weight = config.tree_weights(tree_idx);
+  const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
+  QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
+  const auto& leaf_node = tree.nodes(leaf_idx);
+  QCHECK(leaf_node.has_leaf())
+      << "Invalid leaf node: " << leaf_node.DebugString();
+  if (leaf_node.leaf().has_sparse_vector()) {
+    const auto& leaf = leaf_node.leaf().sparse_vector();
+    QCHECK_EQ(leaf.index_size(), leaf.value_size());
+    for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) {
+      const float value = tree_weight * leaf.value(class_idx);
+
+      UpdatePredictions(example.example_idx, leaf.index(class_idx), value,
+                        output_predictions, additional_output_predictions);
+    }
+  } else {
+    QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
+    const auto& leaf = leaf_node.leaf().vector();
+    for (size_t i = 0; i < leaf.value_size(); ++i) {
+      const float value = tree_weight * leaf.value(i);
+      UpdatePredictions(example.example_idx, i, value, output_predictions,
+                        additional_output_predictions);
+    }
+  }
+}
+
+}  // namespace
+
+void MultipleAdditiveTrees::Predict(
+    const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
+    const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
+    const boosted_trees::utils::BatchFeatures& features,
+    tensorflow::thread::ThreadPool* worker_threads,
+    tensorflow::TTypes<float>::Matrix output_predictions,
+    tensorflow::TTypes<float>::Matrix no_dropout_predictions) {
+  // Zero out predictions as the model is additive.
+  output_predictions.setZero();
+  no_dropout_predictions.setZero();
+
+  // Get batch size.
+  const int64 batch_size = features.batch_size();
+  if (batch_size <= 0) {
+    return;
+  }
+
+  // Prepare the list of trees to keep.
+  std::vector<int32> trees_to_keep;
+  CalculateTreesToKeep(config, trees_to_drop, config.trees_size(),
+                       only_finalized_trees, &trees_to_keep);
+
+  // Lambda for doing a block of work.
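+  // Each invocation of the lambda below handles the half-open example range
+  // [start, end); an example only ever writes to its own row of the output
+  // matrices, so concurrent blocks touch disjoint rows.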
+  auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop,
+                             &output_predictions,
+                             &no_dropout_predictions](int64 start, int64 end) {
+    auto examples_iterable = features.examples_iterable(start, end);
+    for (const auto& example : examples_iterable) {
+      for (const int32 tree_idx : trees_to_keep) {
+        UpdatePredictionsBasedOnTree(config, tree_idx, example,
+                                     &output_predictions,
+                                     &no_dropout_predictions);
+      }
+
+      // Now do predictions for dropped trees.
+      for (const int32 tree_idx : trees_to_drop) {
+        UpdatePredictionsBasedOnTree(config, tree_idx, example,
+                                     &no_dropout_predictions, nullptr);
+      }
+    }
+  };
+
+  // TODO(salehay): parallelize this for low latency in serving path where
+  // batch size tends to be small but ensemble size tends to be large.
+  boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
+                                    worker_threads, update_predictions);
+}
+
+}  // namespace models
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
new file mode 100644
index 00000000000..fedade20261
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
@@ -0,0 +1,50 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace models {
+
+// Multiple additive trees prediction model.
+// This class does not hold state and is thread safe.
+class MultipleAdditiveTrees {
+ public:
+  // Predict runs tree ensemble on the given batch and updates
+  // output predictions accordingly. The method also returns predictions that
+  // we would get if no dropout was applied.
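+  // Both output matrices are zeroed before accumulation, so callers may pass
+  // buffers of shape [batch_size, num_classes] without initializing them.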
+  static void Predict(
+      const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
+      const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
+      const boosted_trees::utils::BatchFeatures& features,
+      thread::ThreadPool* const thread_pool,
+      TTypes<float>::Matrix output_predictions,
+      TTypes<float>::Matrix no_dropout_predictions);
+};
+
+}  // namespace models
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
new file mode 100644
index 00000000000..5f0924b48f2
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
@@ -0,0 +1,381 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
+
+#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
+#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
+#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+using boosted_trees::trees::DecisionTreeEnsembleConfig;
+using test::AsTensor;
+
+namespace boosted_trees {
+namespace models {
+namespace {
+
+const int32 kNumThreadsMultiThreaded = 6;
+const int32 kNumThreadsSingleThreaded = 1;
+
+class MultipleAdditiveTreesTest : public ::testing::Test {
+ protected:
+  MultipleAdditiveTreesTest() : batch_features_(2) {
+    // Create a batch of two examples having one dense feature each.
+    // The shape of the dense matrix is therefore 2x1 as in one row per example
+    // and one column per feature per example.
+    auto dense_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
+    TF_EXPECT_OK(
+        batch_features_.Initialize({dense_matrix}, {}, {}, {}, {}, {}, {}));
+  }
+
+  boosted_trees::utils::BatchFeatures batch_features_;
+};
+
+TEST_F(MultipleAdditiveTreesTest, Empty) {
+  // Create empty tree ensemble.
+  DecisionTreeEnsembleConfig tree_ensemble_config;
+  auto output_tensor = AsTensor<float>({9.0f, 23.0f}, {2, 1});
+  auto output_matrix = output_tensor.matrix<float>();
+  auto no_dropout_output_matrix = output_tensor.matrix<float>();
+
+  // Predict for both instances.
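+  // (The output tensor is deliberately seeded with non-zero values so this
+  // test verifies that Predict zeroes its outputs before accumulating.)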
+  tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
+                                         kNumThreadsSingleThreaded);
+  MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                 false,  // include non-finalized trees
+                                 {}, batch_features_, &threads, output_matrix,
+                                 no_dropout_output_matrix);
+  EXPECT_EQ(0, output_matrix(0, 0));
+  EXPECT_EQ(0, output_matrix(1, 0));
+
+  // There was no dropout.
+  for (int i = 0; i < 2; ++i) {
+    EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
+  }
+}
+
+TEST_F(MultipleAdditiveTreesTest, SingleClass) {
+  // Add one bias and one stump to ensemble for a single class.
+  DecisionTreeEnsembleConfig tree_ensemble_config;
+  auto* tree1 = tree_ensemble_config.add_trees();
+  auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  bias_leaf->add_index(0);
+  bias_leaf->add_value(-0.4f);
+  auto* tree2 = tree_ensemble_config.add_trees();
+  auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
+  dense_split->set_feature_column(0);
+  dense_split->set_threshold(5.0f);
+  dense_split->set_left_id(1);
+  dense_split->set_right_id(2);
+  auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  leaf1->add_index(0);
+  leaf1->add_value(0.9f);
+  auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  leaf2->add_index(0);
+  leaf2->add_value(0.2f);
+
+  tree_ensemble_config.add_tree_weights(1.0);
+  tree_ensemble_config.add_tree_weights(1.0);
+
+  auto output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
+  auto output_matrix = output_tensor.matrix<float>();
+
+  auto no_dropout_output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
+  auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
+
+  tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
+                                         kNumThreadsSingleThreaded);
+
+  // Normal case.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {}, batch_features_, &threads, output_matrix,
+                                   no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (bias) + 0.9 (leaf 1).
+
+    // No dropout predictions are the same.
+    for (int i = 0; i < 2; ++i) {
+      EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
+    }
+  }
+  // Weighted case.
+  {
+    DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
+    weighted.set_tree_weights(0, 6.0);
+    weighted.set_tree_weights(1, 3.2);
+    MultipleAdditiveTrees::Predict(weighted,
+                                   false,  // include non-finalized trees
+                                   {}, batch_features_, &threads, output_matrix,
+                                   no_dropout_output_matrix);
+    // -0.4 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
+    // -0.4 (bias) + 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0));
+
+    // No dropout predictions are the same.
+    for (int i = 0; i < 2; ++i) {
+      EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
+    }
+  }
+  // Drop first tree.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {0}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0));  // 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0));  // 0.9 (leaf 1).
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(
+        -0.2f, no_dropout_output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+  }
+  // Drop second tree.
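+  // (Dropping a tree removes its contribution from output_matrix only; the
+  // no-dropout matrix always reflects the full ensemble.)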
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {1}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0));  // -0.4 (bias).
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(
+        -0.2f, no_dropout_output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+  }
+  // Drop all trees.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {0, 1}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
+    EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(
+        -0.2f, no_dropout_output_matrix(0, 0));  // -0.4 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+  }
+}
+
+TEST_F(MultipleAdditiveTreesTest, MultiClass) {
+  // Add one bias and one stump to ensemble for two classes.
+  DecisionTreeEnsembleConfig tree_ensemble_config;
+  auto* tree1 = tree_ensemble_config.add_trees();
+  auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  bias_leaf->add_index(0);
+  bias_leaf->add_value(-0.4f);
+  bias_leaf->add_index(1);
+  bias_leaf->add_value(-0.7f);
+  auto* tree2 = tree_ensemble_config.add_trees();
+  auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
+  dense_split->set_feature_column(0);
+  dense_split->set_threshold(5.0f);
+  dense_split->set_left_id(1);
+  dense_split->set_right_id(2);
+  auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  leaf1->add_index(0);
+  leaf1->add_value(0.9f);
+  auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
+  leaf2->add_index(1);
+  leaf2->add_value(0.2f);
+
+  tree_ensemble_config.add_tree_weights(1.0);
+  tree_ensemble_config.add_tree_weights(1.0);
+
+  // Predict for both instances.
+  tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
+                                         kNumThreadsSingleThreaded);
+  auto output_tensor = AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
+  auto output_matrix = output_tensor.matrix<float>();
+
+  auto no_dropout_output_tensor =
+      AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
+  auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
+
+  // Normal case.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {}, batch_features_, &threads, output_matrix,
+                                   no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias)
+    EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1));  // -0.7 (bias) + 0.2 (leaf 2)
+    EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (bias) + 0.9 (leaf 1)
+    EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1));  // -0.7 (bias)
+
+    // No dropout predictions are the same.
+    for (int i = 0; i < 2; ++i) {
+      for (int j = 0; j < 2; ++j) {
+        EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
+      }
+    }
+  }
+  // Weighted case.
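+  // (Tree weights scale leaf values linearly: each prediction is the sum over
+  // trees of tree_weight * leaf_value.)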
+  {
+    DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
+    weighted.set_tree_weights(0, 6.0);
+    weighted.set_tree_weights(1, 3.2);
+    MultipleAdditiveTrees::Predict(weighted,
+                                   false,  // include non-finalized trees
+                                   {}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    // bias
+    EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
+    // bias + leaf 2
+    EXPECT_FLOAT_EQ(-0.7f * 6 + 0.2f * 3.2, output_matrix(0, 1));
+    // bias + leaf 1
+    EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9f * 3.2f, output_matrix(1, 0));
+    // bias
+    EXPECT_FLOAT_EQ(-0.7f * 6, output_matrix(1, 1));
+  }
+  // Dropout first tree.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {0}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
+    EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1));  // 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0));  // 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(
+        -0.5f, no_dropout_output_matrix(0, 1));  // -0.7 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1));  // -0.7 (bias).
+  }
+  // Dropout second tree.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {1}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1));  // -0.7 (bias).
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1));  // -0.7 (bias).
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(
+        -0.5f, no_dropout_output_matrix(0, 1));  // -0.7 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1));  // -0.7 (bias).
+  }
+  // Drop both trees.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {0, 1}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
+    EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
+    EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
+    EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
+
+    // No dropout predictions.
+    EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0));  // -0.4 (bias).
+    EXPECT_FLOAT_EQ(
+        -0.5f, no_dropout_output_matrix(0, 1));  // -0.7 (bias) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(
+        0.5f, no_dropout_output_matrix(1, 0));  // -0.4 (bias) + 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1));  // -0.7 (bias).
+  }
+}
+
+TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
+  DecisionTreeEnsembleConfig tree_ensemble_config;
+  auto* tree1 = tree_ensemble_config.add_trees();
+  auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_vector();
+  bias_leaf->add_value(-0.4f);
+  bias_leaf->add_value(-0.7f);
+  bias_leaf->add_value(3.0f);
+  auto* tree2 = tree_ensemble_config.add_trees();
+  auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
+  dense_split->set_feature_column(0);
+  dense_split->set_threshold(5.0f);
+  dense_split->set_left_id(1);
+  dense_split->set_right_id(2);
+  auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
+  leaf1->add_value(0.9f);
+  leaf1->add_value(0.8f);
+  leaf1->add_value(0.7f);
+  auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
+  leaf2->add_value(0.2f);
+  leaf2->add_value(0.3f);
+  leaf2->add_value(0.4f);
+
+  tree_ensemble_config.add_tree_weights(1.0);
+  tree_ensemble_config.add_tree_weights(1.0);
+
+  // Predict for both instances.
+  tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
+                                         kNumThreadsSingleThreaded);
+  auto output_tensor =
+      AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
+  auto output_matrix = output_tensor.matrix<float>();
+
+  auto no_dropout_output_tensor =
+      AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
+  auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
+
+  // Normal case.
+  {
+    MultipleAdditiveTrees::Predict(tree_ensemble_config,
+                                   false,  // include non-finalized trees
+                                   {}, batch_features_, &threads,
+                                   output_matrix, no_dropout_output_matrix);
+    EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0));  // -0.4 (tree1) + 0.2 (leaf 2).
+    EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1));  // -0.7 (tree1) + 0.3 (leaf 2).
+    EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2));   // 3.0 (tree1) + 0.4 (leaf 2).
+    EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0));   // -0.4 (tree1) + 0.9 (leaf 1).
+    EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1));   // -0.7 (tree1) + 0.8 (leaf 1).
+    EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2));   // 3.0 (tree1) + 0.7 (leaf 1).
+
+    // No dropout predictions are the same.
+    for (int i = 0; i < 2; ++i) {
+      for (int j = 0; j < 3; ++j) {
+        EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
+      }
+    }
+  }
+}
+
+}  // namespace
+}  // namespace models
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc
new file mode 100644
index 00000000000..39c2fbe9c99
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc
@@ -0,0 +1,88 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+using tensorflow::Tensor;
+
+void RandomlyInitializeBatchFeatures(
+    tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
+    uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
+    boosted_trees::utils::BatchFeatures* batch_features) {
+  const int64 batch_size = static_cast<int64>(batch_features->batch_size());
+
+  // Populate dense features.
+  std::vector<Tensor> dense_float_features_list;
+  for (int i = 0; i < num_dense_float_features; ++i) {
+    std::vector<float> values;
+    for (int64 j = 0; j < batch_size; ++j) {
+      values.push_back(rng->RandFloat());
+    }
+    auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1});
+    tensorflow::test::FillValues<float>(&dense_tensor, values);
+    dense_float_features_list.push_back(dense_tensor);
+  }
+
+  // Populate sparse features.
+  std::vector<Tensor> sparse_float_feature_indices_list;
+  std::vector<Tensor> sparse_float_feature_values_list;
+  std::vector<Tensor> sparse_float_feature_shapes_list;
+  for (int i = 0; i < num_sparse_float_features; ++i) {
+    std::set<uint64> indices;
+    const double sparsity =
+        sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo);
+    const double density = 1 - sparsity;
+    for (int64 k = 0; k < static_cast<int64>(density * batch_size) + 1; ++k) {
+      indices.insert(rng->Uniform64(batch_size));
+    }
+    const int64 sparse_values_size = indices.size();
+    std::vector<int64> indices_vector;
+    for (auto idx : indices) {
+      indices_vector.push_back(idx);
+      indices_vector.push_back(0);
+    }
+    auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2});
+    tensorflow::test::FillValues<int64>(&indices_tensor, indices_vector);
+    sparse_float_feature_indices_list.push_back(indices_tensor);
+
+    std::vector<float> values;
+    for (int64 j = 0; j < sparse_values_size; ++j) {
+      values.push_back(rng->RandFloat());
+    }
+    auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size});
+    tensorflow::test::FillValues<float>(&values_tensor, values);
+    sparse_float_feature_values_list.push_back(values_tensor);
+
+    auto shape_tensor = Tensor(tensorflow::DT_INT64, {2});
+    tensorflow::test::FillValues<int64>(&shape_tensor, {batch_size, 1});
+    sparse_float_feature_shapes_list.push_back(shape_tensor);
+  }
+
+  // TODO(salehay): Add categorical feature generation support.
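+  // Note: the three empty vectors passed to Initialize below are the sparse
+  // int (categorical) feature indices, values and shapes, which are not
+  // generated yet.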
+  TF_EXPECT_OK(batch_features->Initialize(
+      dense_float_features_list, sparse_float_feature_indices_list,
+      sparse_float_feature_values_list, sparse_float_feature_shapes_list, {},
+      {}, {}));
+}
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h
new file mode 100644
index 00000000000..d95878ec87b
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h
@@ -0,0 +1,45 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+// This method calls Initialize on the given 'batch_features', which will be
+// populated with randomly generated feature values when the call returns.
+//
+// All float features will be either missing or uniformly randomly chosen
+// from [0, 1). For sparse (float) features, a sparsity is uniformly randomly
+// chosen from ['sparsity_lo', 'sparsity_hi') per feature, and each instance
+// misses that feature with probability equal to the sparsity; in other
+// words, sparsity = 1 - density.
+void RandomlyInitializeBatchFeatures(
+    tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
+    uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
+    boosted_trees::utils::BatchFeatures* batch_features);
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
new file mode 100644
index 00000000000..cbe26ba918d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
@@ -0,0 +1,211 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+using tensorflow::boosted_trees::trees::DecisionTreeConfig;
+using tensorflow::boosted_trees::trees::TreeNode;
+using boosted_trees::trees::DenseFloatBinarySplit;
+
+namespace {
+
+// Append the given nodes to tree with transfer of pointer ownership.
+// nodes will not be usable upon return.
+template <typename T>
+void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
+  std::reverse(nodes->pointer_begin(), nodes->pointer_end());
+  while (!nodes->empty()) {
+    tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
+  }
+}
+
+DenseFloatBinarySplit* GetSplit(TreeNode* node) {
+  switch (node->node_case()) {
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft:
+      return node->mutable_sparse_float_binary_split_default_left()
+          ->mutable_split();
+    case TreeNode::kSparseFloatBinarySplitDefaultRight:
+      return node->mutable_sparse_float_binary_split_default_right()
+          ->mutable_split();
+    case TreeNode::kDenseFloatBinarySplit:
+      return node->mutable_dense_float_binary_split();
+    default:
+      LOG(FATAL) << "Unknown node type encountered.";
+  }
+  return nullptr;
+}
+
+}  // namespace
+
+RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
+                             int dense_feature_size, int sparse_feature_size)
+    : rng_(rng),
+      dense_feature_size_(dense_feature_size),
+      sparse_feature_size_(sparse_feature_size) {}
+
+namespace {
+void AddWeightAndMetadata(
+    boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
+  // Assign the weight of the tree to 1 and say that this weight was updated
+  // only once.
+  ret->add_tree_weights(1.0);
+  auto* meta = ret->add_tree_metadata();
+  meta->set_num_tree_weight_updates(1);
+}
+
+}  // namespace
+
+boosted_trees::trees::DecisionTreeEnsembleConfig
+RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
+  boosted_trees::trees::DecisionTreeEnsembleConfig ret;
+  *(ret.add_trees()) = Generate(depth);
+  AddWeightAndMetadata(&ret);
+  for (int i = 1; i < tree_count; ++i) {
+    *(ret.add_trees()) = Generate(ret.trees(0));
+    AddWeightAndMetadata(&ret);
+  }
+  return ret;
+}
+
+DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
+  DecisionTreeConfig ret = tree;
+  for (auto& node : *ret.mutable_nodes()) {
+    if (node.node_case() == TreeNode::kLeaf) {
+      node.mutable_leaf()->mutable_sparse_vector()->set_value(
+          0, rng_->RandFloat());
+      continue;
+    }
+    // Original node is a split. Re-generate its type but retain the split
+    // node indices.
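+    // For example, a dense split may be regenerated as a sparse split (or
+    // vice versa) with a fresh threshold and feature column, while its
+    // left_id/right_id links into the tree are preserved.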
+    DenseFloatBinarySplit* split = GetSplit(&node);
+    const int left_id = split->left_id();
+    const int right_id = split->right_id();
+    GenerateSplit(&node, left_id, right_id);
+  }
+  return ret;
+}
+
+DecisionTreeConfig RandomTreeGen::Generate(int depth) {
+  DecisionTreeConfig ret;
+  // Add root.
+  TreeNode* node = ret.add_nodes();
+  GenerateSplit(node, 1, 2);
+  if (depth == 1) {
+    // Add left and right leaves.
+    TreeNode* left = ret.add_nodes();
+    left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
+    left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
+    TreeNode* right = ret.add_nodes();
+    right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
+    right->mutable_leaf()->mutable_sparse_vector()->add_value(
+        rng_->RandFloat());
+    return ret;
+  } else {
+    DecisionTreeConfig left_branch = Generate(depth - 1);
+    DecisionTreeConfig right_branch = Generate(depth - 1);
+    Combine(&ret, &left_branch, &right_branch);
+    return ret;
+  }
+}
+
+void RandomTreeGen::Combine(DecisionTreeConfig* root,
+                            DecisionTreeConfig* left_branch,
+                            DecisionTreeConfig* right_branch) {
+  const int left_branch_size = left_branch->nodes_size();
+  CHECK_EQ(1, root->nodes_size());
+  // left_branch starts its index at 1. right_branch starts its index at
+  // (left_branch_size + 1).
+  auto* root_node = root->mutable_nodes(0);
+  DenseFloatBinarySplit* root_split = GetSplit(root_node);
+  root_split->set_left_id(1);
+  root_split->set_right_id(left_branch_size + 1);
+  // Shift left/right branch's indices internally so that everything is
+  // consistent.
+  ShiftNodeIndex(left_branch, 1);
+  ShiftNodeIndex(right_branch, left_branch_size + 1);
+
+  // Complexity O(branch node size). No proto copying though.
+  AppendNodes(root, left_branch->mutable_nodes());
+  AppendNodes(root, right_branch->mutable_nodes());
+}
+
+void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
+  for (TreeNode& node : *(tree->mutable_nodes())) {
+    DenseFloatBinarySplit* split = nullptr;
+    switch (node.node_case()) {
+      case TreeNode::kLeaf:
+        break;
+      case TreeNode::kSparseFloatBinarySplitDefaultLeft:
+        split = node.mutable_sparse_float_binary_split_default_left()
+                    ->mutable_split();
+        break;
+      case TreeNode::kSparseFloatBinarySplitDefaultRight:
+        split = node.mutable_sparse_float_binary_split_default_right()
+                    ->mutable_split();
+        break;
+      case TreeNode::kDenseFloatBinarySplit:
+        split = node.mutable_dense_float_binary_split();
+        break;
+      default:
+        LOG(FATAL) << "Unknown node type encountered.";
+    }
+    if (split) {
+      split->set_left_id(shift + split->left_id());
+      split->set_right_id(shift + split->right_id());
+    }
+  }
+}
+
+void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
+  const double denseSplitProb =
+      sparse_feature_size_ == 0
+          ? 1.0
+          : static_cast<double>(dense_feature_size_) /
+                (dense_feature_size_ + sparse_feature_size_);
+  // Generate the tree such that it has equal probability of going left and
+  // right when the feature is missing.
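+  // (Default-left and default-right sparse splits are each chosen with
+  // probability kLeftProb below.)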
+  static constexpr float kLeftProb = 0.5;
+
+  DenseFloatBinarySplit* split;
+  int feature_size;
+  if (rng_->RandFloat() < denseSplitProb) {
+    feature_size = dense_feature_size_;
+    split = node->mutable_dense_float_binary_split();
+  } else {
+    feature_size = sparse_feature_size_;
+    if (rng_->RandFloat() < kLeftProb) {
+      split = node->mutable_sparse_float_binary_split_default_left()
+                  ->mutable_split();
+    } else {
+      split = node->mutable_sparse_float_binary_split_default_right()
+                  ->mutable_split();
+    }
+  }
+  split->set_threshold(rng_->RandFloat());
+  split->set_feature_column(rng_->Uniform(feature_size));
+  split->set_left_id(left_id);
+  split->set_right_id(right_id);
+}
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
new file mode 100644
index 00000000000..dc584bbd3cf
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
@@ -0,0 +1,75 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+// Randomly generates a balanced tree, for performance benchmarking purposes,
+// that assumes all features are sparse float features, for now.
+class RandomTreeGen {
+ public:
+  RandomTreeGen(tensorflow::random::SimplePhilox* rng, int dense_feature_size,
+                int sparse_feature_size);
+
+  // Required: depth must be >= 1.
+  // If one wants to generate multiple trees with the same depth, see also the
+  // overload below.
+  boosted_trees::trees::DecisionTreeConfig Generate(int depth);
+
+  // Randomly generate a new tree with the same depth (and tree structure)
+  // as the given tree. This is faster.
+  boosted_trees::trees::DecisionTreeConfig Generate(
+      const boosted_trees::trees::DecisionTreeConfig& tree);
+
+  // Required: depth >= 1; tree_count >= 1.
+  boosted_trees::trees::DecisionTreeEnsembleConfig GenerateEnsemble(
+      int depth, int tree_count);
+
+ private:
+  tensorflow::random::SimplePhilox* rng_;
+  const int dense_feature_size_;
+  const int sparse_feature_size_;
+
+  // Put together a deeper tree by combining two trees.
+  void Combine(boosted_trees::trees::DecisionTreeConfig* root,
+               boosted_trees::trees::DecisionTreeConfig* left_branch,
+               boosted_trees::trees::DecisionTreeConfig* right_branch);
+
+  // For each node in the provided tree, shift its referenced left/right index
+  // by shift.
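+  // For example, with shift = 2, a split whose children are left_id = 1 and
+  // right_id = 2 is re-linked to left_id = 3 and right_id = 4; leaf nodes
+  // are left untouched.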
+  void ShiftNodeIndex(boosted_trees::trees::DecisionTreeConfig* tree,
+                      int shift);
+
+  // Generate a split (dense or sparse) in the node.
+  void GenerateSplit(boosted_trees::trees::TreeNode* node, int left_id,
+                     int right_id);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RandomTreeGen);
+};
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc
new file mode 100644
index 00000000000..5ea81e8d9a4
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc
@@ -0,0 +1,67 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// Randomly generate a tree ensemble and write to file.
+
+#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+using tensorflow::Flag;
+using tensorflow::Flags;
+using tensorflow::int32;
+using tensorflow::string;
+
+int main(int argc, char* argv[]) {
+  int32 dense_feature_size = 100;
+  int32 sparse_feature_size = 100;
+  int32 depth = 8;
+  int32 tree_count = 10;
+  string filename = "/tmp/trees.pb";
+  std::vector<Flag> flag_list = {
+      Flag("dense_feature_size", &dense_feature_size, "dense feature size"),
+      Flag("sparse_feature_size", &sparse_feature_size, "sparse feature size"),
+      Flag("depth", &depth, "tree depth"),
+      Flag("tree_count", &tree_count, "tree count"),
+      Flag("filename", &filename, "Output filename."),
+  };
+  string usage = Flags::Usage(argv[0], flag_list);
+  const bool parse_result = Flags::Parse(&argc, argv, flag_list);
+  // We need to call this to set up global state for TensorFlow.
+ tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return -1; + } + + tensorflow::random::PhiloxRandom philox(1); + tensorflow::random::SimplePhilox rng(&philox); + tensorflow::boosted_trees::testutil::RandomTreeGen tree_gen( + &rng, dense_feature_size, sparse_feature_size); + const auto& trees = tree_gen.GenerateEnsemble(depth, tree_count); + tensorflow::Status status = + tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, trees); + if (!status.ok()) { + LOG(WARNING) << "Failed to write: " << filename << " : " << status; + } else { + LOG(INFO) << "Tree ensemble written to: " << filename; + } + return 0; +} diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc new file mode 100644 index 00000000000..318d8a5296e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -0,0 +1,170 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace boosted_trees { +namespace trees { + +constexpr int kInvalidLeaf = -1; +int DecisionTree::Traverse(const DecisionTreeConfig& config, + const int32 sub_root_id, + const utils::Example& example) { + if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { + return kInvalidLeaf; + } + + // Traverse tree starting at the provided sub-root. + int32 node_id = sub_root_id; + while (true) { + const auto& current_node = config.nodes(node_id); + switch (current_node.node_case()) { + case TreeNode::kLeaf: { + return node_id; + } + case TreeNode::kDenseFloatBinarySplit: { + const auto& split = current_node.dense_float_binary_split(); + node_id = example.dense_float_features[split.feature_column()] <= + split.threshold() + ? split.left_id() + : split.right_id(); + break; + } + case TreeNode::kSparseFloatBinarySplitDefaultLeft: { + const auto& split = + current_node.sparse_float_binary_split_default_left().split(); + auto sparse_feature = + example.sparse_float_features[split.feature_column()]; + node_id = !sparse_feature.has_value() || + sparse_feature.get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); + break; + } + case TreeNode::kSparseFloatBinarySplitDefaultRight: { + const auto& split = + current_node.sparse_float_binary_split_default_right().split(); + auto sparse_feature = + example.sparse_float_features[split.feature_column()]; + node_id = sparse_feature.has_value() && + sparse_feature.get_value() <= split.threshold() + ? 
split.left_id()
+                      : split.right_id();
+        break;
+      }
+      case TreeNode::kCategoricalIdBinarySplit: {
+        const auto& split = current_node.categorical_id_binary_split();
+        node_id = example.sparse_int_features[split.feature_column()].count(
+                      split.feature_id()) > 0
+                      ? split.left_id()
+                      : split.right_id();
+        break;
+      }
+      case TreeNode::NODE_NOT_SET: {
+        QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
+        break;
+      }
+    }
+    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
+                          << current_node.DebugString();
+  }
+}
+
+void DecisionTree::LinkChildren(const std::vector<int32>& children,
+                                TreeNode* parent_node) {
+  // Decide how to link children depending on the parent node's type.
+  auto children_it = children.begin();
+  switch (parent_node->node_case()) {
+    case TreeNode::kLeaf: {
+      // Essentially no-op.
+      QCHECK(children.empty()) << "A leaf node cannot have children.";
+      break;
+    }
+    case TreeNode::kDenseFloatBinarySplit: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split = parent_node->mutable_dense_float_binary_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split =
+          parent_node->mutable_sparse_float_binary_split_default_left()
+              ->mutable_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split =
+          parent_node->mutable_sparse_float_binary_split_default_right()
+              ->mutable_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kCategoricalIdBinarySplit: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split = parent_node->mutable_categorical_id_binary_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::NODE_NOT_SET: {
+      QCHECK(false) << "A non-set node cannot have children.";
+      break;
+    }
+  }
+}
+
+std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
+  // A node's children depend on its type.
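+  // Leaf and unset nodes have no children; every binary split type has
+  // exactly two (left, then right).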
+  switch (node.node_case()) {
+    case TreeNode::kLeaf: {
+      return {};
+    }
+    case TreeNode::kDenseFloatBinarySplit: {
+      const auto& split = node.dense_float_binary_split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
+      const auto& split = node.sparse_float_binary_split_default_left().split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
+      const auto& split =
+          node.sparse_float_binary_split_default_right().split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kCategoricalIdBinarySplit: {
+      const auto& split = node.categorical_id_binary_split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::NODE_NOT_SET: {
+      return {};
+    }
+  }
+}
+
+}  // namespace trees
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
new file mode 100644
index 00000000000..604ff02744b
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
@@ -0,0 +1,49 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace trees {
+
+// Decision tree class to encapsulate tree traversal and mutation logic.
+// This class does not hold state and is thread safe.
+class DecisionTree {
+ public:
+  // Traverse the tree starting from the given sub-root for an instance with
+  // the given set of features, and return the leaf index, or -1 if the tree
+  // is empty or the sub-root is invalid.
+  static int Traverse(const DecisionTreeConfig& config, int32 sub_root_id,
+                      const utils::Example& example);
+
+  // Links the specified children to the parent. The children must
+  // already be added to the decision tree config, so this method
+  // just ensures nodes are re-linked.
+  static void LinkChildren(const std::vector<int32>& children,
+                           TreeNode* parent_node);
+
+  // Retrieves node children indices if any.
+  static std::vector<int32> GetChildren(const TreeNode& node);
+};
+
+}  // namespace trees
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
new file mode 100644
index 00000000000..0f082d7fd54
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
@@ -0,0 +1,326 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace trees {
+namespace {
+
+class DecisionTreeTest : public ::testing::Test {
+ protected:
+  DecisionTreeTest() : batch_features_(2) {
+    // Create a batch of two examples having one dense float, two sparse float
+    // and one sparse int feature.
+    // The first example is missing the second sparse feature column and the
+    // second example is missing the first sparse feature column.
+    // This looks like the following:
+    // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 |
+    // 0        |    7    |    -3    |          |    3     |
+    // 1        |   -2    |          |    4     |          |
+    auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
+    auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
+    auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
+    auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
+    auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
+    auto sparse_float_values2 = test::AsTensor<float>({4.0f});
+    auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
+    auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
+    auto sparse_int_values1 = test::AsTensor<int64>({3});
+    auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
+    TF_EXPECT_OK(batch_features_.Initialize(
+        {dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2},
+        {sparse_float_values1, sparse_float_values2},
+        {sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1},
+        {sparse_int_values1}, {sparse_int_shape1}));
+  }
+
+  template <typename SplitType>
+  void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
+    // Verify children were linked.
+    DecisionTree::LinkChildren({3, 8}, node);
+    EXPECT_EQ(3, split->left_id());
+    EXPECT_EQ(8, split->right_id());
+
+    // Invalid cases.
+    EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
+                 "A binary split node must have exactly two children.");
+    EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
+                 "A binary split node must have exactly two children.");
+    EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
+                 "A binary split node must have exactly two children.");
+  }
+
+  void TestGetChildren(const TreeNode& node,
+                       const std::vector<int32>& expected_children) {
+    // Verify children match the expected ones.
+ auto children = DecisionTree::GetChildren(node); + EXPECT_EQ(children.size(), expected_children.size()); + for (size_t idx = 0; idx < children.size(); ++idx) { + EXPECT_EQ(children[idx], expected_children[idx]); + } + } + + utils::BatchFeatures batch_features_; +}; + +TEST_F(DecisionTreeTest, TraverseEmpty) { + DecisionTreeConfig tree_config; + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example)); +} + +TEST_F(DecisionTreeTest, TraverseBias) { + DecisionTreeConfig tree_config; + tree_config.add_nodes()->mutable_leaf(); + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example)); +} + +TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) { + DecisionTreeConfig tree_config; + tree_config.add_nodes()->mutable_leaf(); + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example)); +} + +TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) { + DecisionTreeConfig tree_config; + auto* split_node = + tree_config.add_nodes()->mutable_dense_float_binary_split(); + split_node->set_feature_column(0); + split_node->set_threshold(0.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect right child to be picked as !(7 <= 0); + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as (-2 <= 0); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { + // Test first sparse feature which is missing for the second example. + DecisionTreeConfig tree_config1; + auto* split_node1 = tree_config1.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + split_node1->set_feature_column(0); + split_node1->set_threshold(-20.0f); + split_node1->set_left_id(1); + split_node1->set_right_id(2); + tree_config1.add_nodes()->mutable_leaf(); + tree_config1.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect right child to be picked as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it)); + + // Expect left child to be picked as default direction. + EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it)); + + // Test second sparse feature which is missing for the first example. + DecisionTreeConfig tree_config2; + auto* split_node2 = tree_config2.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node2->set_feature_column(1); + split_node2->set_threshold(4.0f); + split_node2->set_left_id(1); + split_node2->set_right_id(2); + tree_config2.add_nodes()->mutable_leaf(); + tree_config2.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as default direction. + example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it)); + + // Expect left child to be picked as (4 <= 4). 
+ EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) { + DecisionTreeConfig tree_config; + auto* split_node = + tree_config.add_nodes()->mutable_categorical_id_binary_split(); + split_node->set_feature_column(0); + split_node->set_feature_id(3); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect left child to be picked as 3 == 3; + auto example_it = example_iterable.begin(); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect right child to be picked as the feature is missing; + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseHybridSplits) { + DecisionTreeConfig tree_config; + auto* split_node1 = + tree_config.add_nodes()->mutable_dense_float_binary_split(); + split_node1->set_feature_column(0); + split_node1->set_threshold(9.0f); + split_node1->set_left_id(1); // sparse split. + split_node1->set_right_id(2); // leaf + auto* split_node2 = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + tree_config.add_nodes()->mutable_leaf(); + split_node2->set_feature_column(0); + split_node2->set_threshold(-20.0f); + split_node2->set_left_id(3); + split_node2->set_right_id(4); + auto* split_node3 = + tree_config.add_nodes()->mutable_categorical_id_binary_split(); + split_node3->set_feature_column(0); + split_node3->set_feature_id(2); + split_node3->set_left_id(5); + split_node3->set_right_id(6); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect will go left through the first dense split as (7.0f <= 9.0f), + // then will go right through the sparse split as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect will go left through the first dense split as (-2.0f <= 9.0f), + // then will go left the default direction as the sparse feature is missing, + // then will go right as 2 != 3 on the categorical split. + EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, LinkChildrenLeaf) { + // Create leaf node. + TreeNode node; + node.mutable_leaf(); + + // No-op. + DecisionTree::LinkChildren({}, &node); + + // Invalid case. 
+ EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), + "A leaf node cannot have children."); +} + +TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) { + TreeNode node; + auto* split = node.mutable_dense_float_binary_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_left()->mutable_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_right()->mutable_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) { + TreeNode node; + auto* split = node.mutable_categorical_id_binary_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) { + // Create unset node. + TreeNode node; + + // Invalid case. + EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), + "A non-set node cannot have children."); +} + +TEST_F(DecisionTreeTest, GetChildrenLeaf) { + TreeNode node; + node.mutable_leaf(); + TestGetChildren(node, {}); +} + +TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) { + TreeNode node; + auto* split = node.mutable_dense_float_binary_split(); + split->set_left_id(23); + split->set_right_id(24); + TestGetChildren(node, {23, 24}); +} + +TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_left()->mutable_split(); + split->set_left_id(12); + split->set_right_id(13); + TestGetChildren(node, {12, 13}); +} + +TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_right()->mutable_split(); + split->set_left_id(1); + split->set_right_id(2); + TestGetChildren(node, {1, 2}); +} + +TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) { + TreeNode node; + auto* split = node.mutable_categorical_id_binary_split(); + split->set_left_id(7); + split->set_right_id(8); + TestGetChildren(node, {7, 8}); +} + +TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) { + TreeNode node; + TestGetChildren(node, {}); +} + +} // namespace +} // namespace trees +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index 3b6b0339d2e..c99d8849bd5 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -24,6 +24,15 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "quantiles_proto", + srcs = [ + "quantiles.proto", + ], + cc_api_version = 2, + visibility = ["//visibility:public"], +) + tf_proto_library( name = "tree_config_proto", srcs = ["tree_config.proto"], diff --git a/tensorflow/contrib/boosted_trees/proto/quantiles.proto b/tensorflow/contrib/boosted_trees/proto/quantiles.proto new file mode 100644 index 00000000000..7f872d2aa71 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/proto/quantiles.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +option cc_enable_arenas 
= true; + +package boosted_trees; + +message QuantileConfig { + // Maximum eps error when computing quantile summaries. + double eps = 1; + // Number of quantiles to generate. + int64 num_quantiles = 2; +} + +message QuantileEntry { + // Value for the entry. + float value = 1; + // Weight for the entry. + float weight = 2; + // We need the minimum and maximum rank possible for this entry. + // Rank is 0.0 for the absolute minimum and sum of the weights for the maximum + // value in the input. + float min_rank = 3; + float max_rank = 4; +} + +message QuantileSummaryState { + repeated QuantileEntry entries = 1; +} + +message QuantileStreamState { + repeated QuantileSummaryState summaries = 1; +} diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD new file mode 100644 index 00000000000..5dfdf8f4896 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -0,0 +1,53 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package( + default_visibility = [ + "//tensorflow/contrib/boosted_trees:__subpackages__", + "//tensorflow/contrib/boosted_trees:friends", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "stamped_resource", + hdrs = ["stamped_resource.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_library( + name = "quantile_stream_resource", + hdrs = ["quantile_stream_resource.h"], + deps = [ + ":stamped_resource", + "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", + "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_library( + name = "decision_tree_ensemble_resource", + hdrs = ["decision_tree_ensemble_resource.h"], + deps = [ + ":stamped_resource", + "//tensorflow/contrib/boosted_trees/lib:trees", + "//tensorflow/core:framework_headers_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h new file mode 100644 index 00000000000..45c3bbadfc8 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -0,0 +1,77 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
+#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace models {
+
+// Keep a tree ensemble in memory for efficient evaluation and mutation.
+class DecisionTreeEnsembleResource : public StampedResource {
+ public:
+  // Constructor.
+  explicit DecisionTreeEnsembleResource()
+      : decision_tree_ensemble_(
+            protobuf::Arena::CreateMessage<
+                boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
+
+  string DebugString() override {
+    return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
+                           decision_tree_ensemble_->trees_size(), "]");
+  }
+
+  const boosted_trees::trees::DecisionTreeEnsembleConfig&
+  decision_tree_ensemble() const {
+    return *decision_tree_ensemble_;
+  }
+
+  boosted_trees::trees::DecisionTreeEnsembleConfig*
+  mutable_decision_tree_ensemble() {
+    return decision_tree_ensemble_;
+  }
+
+  // Resets the resource and frees the protos in arena.
+  // Caller needs to hold the mutex lock while calling this.
+  void Reset() {
+    // Reset stamp.
+    set_stamp(-1);
+
+    // Clear tree ensemble.
+    arena_.Reset();
+    CHECK_EQ(0, arena_.SpaceAllocated());
+    decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
+        boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
+  }
+
+  mutex* get_mutex() { return &mu_; }
+
+ private:
+  protobuf::Arena arena_;
+  mutex mu_;
+  boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
+};
+
+}  // namespace models
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
new file mode 100644
index 00000000000..fb29f79e578
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
@@ -0,0 +1,104 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"  // NOLINT
+#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+
+using QuantileStream =
+    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Resource for accumulating summaries for multiple columns.
+class QuantileStreamResource : public StampedResource {
+ public:
+  QuantileStreamResource(const float epsilon, const int32 num_quantiles,
+                         const int64 max_elements, int64 stamp_token)
+      : stream_(epsilon, max_elements),
+        are_buckets_ready_(false),
+        epsilon_(epsilon),
+        num_quantiles_(num_quantiles),
+        max_elements_(max_elements) {
+    set_stamp(stamp_token);
+  }
+
+  string DebugString() override { return "QuantileStreamResource"; }
+
+  tensorflow::mutex* mutex() { return &mu_; }
+
+  QuantileStream* stream(int64 stamp) {
+    CHECK(is_stamp_valid(stamp));
+    return &stream_;
+  }
+
+  const std::vector<float>& boundaries(int64 stamp) {
+    CHECK(is_stamp_valid(stamp));
+    return boundaries_;
+  }
+
+  void set_boundaries(int64 stamp, const std::vector<float>& boundaries) {
+    CHECK(is_stamp_valid(stamp));
+    are_buckets_ready_ = true;
+    boundaries_ = boundaries;
+  }
+
+  float epsilon() const { return epsilon_; }
+  int32 num_quantiles() const { return num_quantiles_; }
+
+  void Reset(int64 stamp) {
+    set_stamp(stamp);
+    stream_ = QuantileStream(epsilon_, max_elements_);
+  }
+
+  bool are_buckets_ready() const { return are_buckets_ready_; }
+  void set_buckets_ready(bool are_buckets_ready) {
+    are_buckets_ready_ = are_buckets_ready;
+  }
+
+ private:
+  ~QuantileStreamResource() override {}
+
+  // Mutex for the whole resource.
+  tensorflow::mutex mu_;
+
+  // Quantile stream.
+  QuantileStream stream_;
+
+  // Stores the boundaries from the previous iteration. Empty during the first
+  // iteration.
+  std::vector<float> boundaries_;
+
+  // Whether boundaries are created. Initially boundaries are empty until
+  // set_boundaries is called.
+  bool are_buckets_ready_;
+
+  const float epsilon_;
+  const int32 num_quantiles_;
+  // An upper-bound for the number of elements.
+  int64 max_elements_;
+  TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
+};
+
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
new file mode 100644
index 00000000000..aabeeb98516
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
@@ -0,0 +1,42 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace boosted_trees { + +// A StampedResource is a resource that has a stamp token associated with it. +// Before reading from or applying updates to the resource, the stamp should +// be checked to verify that the update is not stale. +class StampedResource : public ResourceBase { + public: + StampedResource() : stamp_(-1) {} + + bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; } + + int64 stamp() const { return stamp_; } + void set_stamp(int64 stamp) { stamp_ = stamp; } + + private: + int64 stamp_; +}; + +} // namespace boosted_trees +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 33384eed480..0663ba16379 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4e300056295..126ef6c00c2 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest 
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 02038da7f85..39fbf603f0e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -111,6 +111,7 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" "${tensorflow_source_dir}/tensorflow/python/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" + "${tensorflow_source_dir}/tensorflow/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" ) @@ -124,6 +125,7 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON( file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/python/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" + "${tensorflow_source_dir}/tensorflow/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" ) @@ -342,6 +344,9 @@ add_python_module("tensorflow/contrib/keras/python/keras/layers") add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") add_python_module("tensorflow/contrib/keras/python/keras/utils") add_python_module("tensorflow/contrib/keras/python/keras/wrappers") +add_python_module("tensorflow/contrib/kernel_methods") +add_python_module("tensorflow/contrib/kernel_methods/python") +add_python_module("tensorflow/contrib/kernel_methods/python/mappers") add_python_module("tensorflow/contrib/labeled_tensor") add_python_module("tensorflow/contrib/labeled_tensor/python") add_python_module("tensorflow/contrib/labeled_tensor/python/ops") @@ -405,6 +410,11 @@ add_python_module("tensorflow/contrib/ndlstm/python") add_python_module("tensorflow/contrib/nn") add_python_module("tensorflow/contrib/nn/python") add_python_module("tensorflow/contrib/nn/python/ops") +add_python_module("tensorflow/contrib/nccl") +add_python_module("tensorflow/contrib/nccl/kernels") +add_python_module("tensorflow/contrib/nccl/ops") +add_python_module("tensorflow/contrib/nccl/python") +add_python_module("tensorflow/contrib/nccl/python/ops") add_python_module("tensorflow/contrib/opt") add_python_module("tensorflow/contrib/opt/python") add_python_module("tensorflow/contrib/opt/python/training") @@ -599,6 +609,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py index f34d7a1fd4d..c473d54f47a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py @@ 
-69,19 +69,25 @@ class BinomialTest(test.TestCase): self.assertEqual((1, 3), binom.logits.get_shape()) self.assertAllClose(logits, binom.logits.eval()) - def testPmfNandCountsAgree(self): + def testPmfAndCdfNandCountsAgree(self): p = [[0.1, 0.2, 0.7]] n = [[5.]] with self.test_session(): binom = binomial.Binomial(total_count=n, probs=p, validate_args=True) binom.prob([2., 3, 2]).eval() binom.prob([3., 1, 2]).eval() + binom.cdf([2., 3, 2]).eval() + binom.cdf([3., 1, 2]).eval() with self.assertRaisesOpError("Condition x >= 0.*"): binom.prob([-1., 4, 2]).eval() with self.assertRaisesOpError("Condition x <= y.*"): binom.prob([7., 3, 0]).eval() + with self.assertRaisesOpError("Condition x >= 0.*"): + binom.cdf([-1., 4, 2]).eval() + with self.assertRaisesOpError("Condition x <= y.*"): + binom.cdf([7., 3, 0]).eval() - def testPmfNonIntegerCounts(self): + def testPmfAndCdfNonIntegerCounts(self): p = [[0.1, 0.2, 0.7]] n = [[5.]] with self.test_session(): @@ -89,50 +95,72 @@ class BinomialTest(test.TestCase): binom = binomial.Binomial(total_count=n, probs=p, validate_args=True) binom.prob([2., 3, 2]).eval() binom.prob([3., 1, 2]).eval() + binom.cdf([2., 3, 2]).eval() + binom.cdf([3., 1, 2]).eval() # Both equality and integer checking fail. with self.assertRaisesOpError( "cannot contain fractional components."): binom.prob([1.0, 2.5, 1.5]).eval() + with self.assertRaisesOpError( + "cannot contain fractional components."): + binom.cdf([1.0, 2.5, 1.5]).eval() binom = binomial.Binomial(total_count=n, probs=p, validate_args=False) binom.prob([1., 2., 3.]).eval() + binom.cdf([1., 2., 3.]).eval() # Non-integer arguments work. binom.prob([1.0, 2.5, 1.5]).eval() + binom.cdf([1.0, 2.5, 1.5]).eval() - def testPmfBothZeroBatches(self): + def testPmfAndCdfBothZeroBatches(self): with self.test_session(): # Both zero-batches. No broadcast p = 0.5 counts = 1. - pmf = binomial.Binomial(total_count=1., probs=p).prob(counts) + binom = binomial.Binomial(total_count=1., probs=p) + pmf = binom.prob(counts) + cdf = binom.cdf(counts) self.assertAllClose(0.5, pmf.eval()) + self.assertAllClose(stats.binom.cdf(counts, n=1, p=p), cdf.eval()) self.assertEqual((), pmf.get_shape()) + self.assertEqual((), cdf.get_shape()) - def testPmfBothZeroBatchesNontrivialN(self): + def testPmfAndCdfBothZeroBatchesNontrivialN(self): with self.test_session(): # Both zero-batches. No broadcast p = 0.1 counts = 3. 
      binom = binomial.Binomial(total_count=5., probs=p)
       pmf = binom.prob(counts)
+      cdf = binom.cdf(counts)
       self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
+      self.assertAllClose(stats.binom.cdf(counts, n=5., p=p), cdf.eval())
       self.assertEqual((), pmf.get_shape())
+      self.assertEqual((), cdf.get_shape())
 
-  def testPmfPStretchedInBroadcastWhenSameRank(self):
+  def testPmfAndCdfPStretchedInBroadcastWhenSameRank(self):
     with self.test_session():
       p = [[0.1, 0.9]]
       counts = [[1., 2.]]
-      pmf = binomial.Binomial(total_count=3., probs=p).prob(counts)
+      binom = binomial.Binomial(total_count=3., probs=p)
+      pmf = binom.prob(counts)
+      cdf = binom.cdf(counts)
       self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
+      self.assertAllClose(stats.binom.cdf(counts, n=3., p=p), cdf.eval())
       self.assertEqual((1, 2), pmf.get_shape())
+      self.assertEqual((1, 2), cdf.get_shape())
 
-  def testPmfPStretchedInBroadcastWhenLowerRank(self):
+  def testPmfAndCdfPStretchedInBroadcastWhenLowerRank(self):
     with self.test_session():
       p = [0.1, 0.4]
       counts = [[1.], [0.]]
-      pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
+      binom = binomial.Binomial(total_count=1., probs=p)
+      pmf = binom.prob(counts)
+      cdf = binom.cdf(counts)
       self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
+      self.assertAllClose([[1.0, 1.0], [0.9, 0.6]], cdf.eval())
       self.assertEqual((2, 2), pmf.get_shape())
+      self.assertEqual((2, 2), cdf.get_shape())
 
   def testBinomialMean(self):
     with self.test_session():
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
index aa9d45f151d..406cd4ebbea 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
@@ -103,6 +103,31 @@ class MultivariateNormalDiagTest(test.TestCase):
       self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.05, rtol=0.05)
 
+  def testSampleWithBroadcastScale(self):
+    # mu corresponds to a 2-batch of 3-variate normals.
+    mu = np.zeros([2, 3])
+
+    # diag has no batch dimensions of its own; it is broadcast against mu.
+    diag = np.ones([3])
+
+    with self.test_session():
+      dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+
+      mean = dist.mean()
+      self.assertAllEqual([2, 3], mean.get_shape())
+      self.assertAllClose(mu, mean.eval())
+
+      n = int(1e3)
+      samps = dist.sample(n, seed=0).eval()
+      cov_mat = array_ops.matrix_diag(diag).eval()**2
+      sample_cov = np.matmul(samps.transpose([1, 2, 0]),
+                             samps.transpose([1, 0, 2])) / n
+
+      self.assertAllClose(mu, samps.mean(axis=0),
+                          atol=0.10, rtol=0.05)
+      self.assertAllClose([cov_mat, cov_mat], sample_cov,
+                          atol=0.10, rtol=0.05)
+
   def testCovariance(self):
     with self.test_session():
       mvn = ds.MultivariateNormalDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index ee7751c9969..78eb44acb2c 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -42,6 +42,28 @@
 to integer values.
 """
 
+
+def _bdtr(k, n, p):
+  """The binomial cumulative distribution function.
+
+  Args:
+    k: floating point `Tensor`.
+    n: floating point `Tensor`.
+    p: floating point `Tensor`.
+
+  Returns:
+    The binomial CDF: `sum_{j=0}^k (n choose j) p^j (1 - p)^(n - j)`.
+  """
+  # Trick for getting safe backprop/gradients into n, k when
+  # betainc(a = 0, ..)
= nan + # Write: + # where(unsafe, safe_output, betainc(where(unsafe, safe_input, input))) + ones = array_ops.ones_like(n - k) + k_eq_n = math_ops.equal(k, n) + safe_dn = array_ops.where(k_eq_n, ones, n - k) + dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p) + return array_ops.where(k_eq_n, ones, dk) + + class Binomial(distribution.Distribution): """Binomial distribution. @@ -201,6 +223,18 @@ class Binomial(distribution.Distribution): def _prob(self, counts): return math_ops.exp(self._log_prob(counts)) + def _cdf(self, counts): + counts = self._maybe_assert_valid_sample(counts) + probs = self.probs + if not (counts.shape.is_fully_defined() + and self.probs.shape.is_fully_defined() + and counts.shape.is_compatible_with(self.probs.shape)): + # If both shapes are well defined and equal, we skip broadcasting. + probs += array_ops.zeros_like(counts) + counts += array_ops.zeros_like(self.probs) + + return _bdtr(k=counts, n=self.total_count, p=probs) + def _log_unnormalized_prob(self, counts): counts = self._maybe_assert_valid_sample(counts) return (counts * math_ops.log(self.probs) + diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 734648d2d64..3bb6bb4af2a 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.contrib.distributions.python.ops import normal from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -53,6 +54,16 @@ or """ +def _broadcast_shape(shape1, shape2): + """Convenience function which statically broadcasts shape when possible.""" + if (tensor_util.constant_value(shape1) is not None and + tensor_util.constant_value(shape2) is not None): + return array_ops.broadcast_static_shape( + tensor_shape.TensorShape(tensor_util.constant_value(shape1)), + tensor_shape.TensorShape(tensor_util.constant_value(shape2))) + return array_ops.broadcast_dynamic_shape(shape1, shape2) + + # TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests. class MultivariateNormalLinearOperator( transformed_distribution.TransformedDistribution): @@ -179,12 +190,25 @@ class MultivariateNormalLinearOperator( if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - # Since expand_dims doesn't preserve constant-ness, we obtain the - # non-dynamic value if possible. - event_shape = scale.domain_dimension_tensor() - if tensor_util.constant_value(event_shape) is not None: - event_shape = tensor_util.constant_value(event_shape) - event_shape = event_shape[array_ops.newaxis] + with ops.name_scope(name, values=[loc] + scale.graph_parents): + # Since expand_dims doesn't preserve constant-ness, we obtain the + # non-dynamic value if possible. 
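As a cross-check on the `_bdtr` helper above: the binomial CDF equals a regularized incomplete beta function, which is what `math_ops.betainc` computes. A sketch of that identity using SciPy (SciPy is an assumption for illustration only; the patch itself uses only TensorFlow ops):

# Verifies the identity behind _bdtr:
#   P(X <= k) = I_{1 - p}(n - k, k + 1)  for X ~ Binomial(n, p) and 0 <= k < n.
from scipy import special, stats

n, p = 5.0, 0.1
for k in (0.0, 1.0, 3.0, 4.0):
  cdf_direct = stats.binom.cdf(k, n, p)
  cdf_betainc = special.betainc(n - k, k + 1, 1.0 - p)
  assert abs(cdf_direct - cdf_betainc) < 1e-12
# At k == n the first betainc argument hits zero, the NaN-gradient corner that
# the where(unsafe, safe_input, input) trick above routes around by returning 1.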
+ event_shape = scale.range_dimension_tensor() + if tensor_util.constant_value(event_shape) is not None: + event_shape = tensor_util.constant_value(event_shape).reshape([1]) + else: + event_shape = event_shape[array_ops.newaxis] + batch_shape = scale.batch_shape_tensor() + if loc is not None: + loc = ops.convert_to_tensor(loc, name="loc") + loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1] + if (loc.get_shape().ndims is None or + not loc_batch_shape.is_fully_defined()): + loc_batch_shape = array_ops.shape(loc)[:-1] + else: + loc_batch_shape = ops.convert_to_tensor(loc_batch_shape, + name="loc_batch_shape") + batch_shape = _broadcast_shape(batch_shape, loc_batch_shape) super(MultivariateNormalLinearOperator, self).__init__( distribution=normal.Normal( @@ -192,7 +216,7 @@ class MultivariateNormalLinearOperator( scale=array_ops.ones([], dtype=scale.dtype)), bijector=bijectors.AffineLinearOperator( shift=loc, scale=scale, validate_args=validate_args), - batch_shape=scale.batch_shape_tensor(), + batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index aa26c6060fb..7f359bea51c 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -35,6 +35,7 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":factorization_ops_test_utils_py", ":gen_clustering_ops", ":gen_factorization_ops", "//tensorflow/contrib/framework:framework_py", @@ -161,12 +162,28 @@ tf_py_test( ], ) +py_library( + name = "factorization_ops_test_utils_py", + srcs = [ + "python/ops/factorization_ops_test_utils.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:embedding_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_ops", + ], +) + tf_py_test( name = "factorization_ops_test", srcs = ["python/ops/factorization_ops_test.py"], additional_deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + ":factorization_ops_test_utils_py", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index 40b8550ac83..bcee8818545 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -18,160 +18,56 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import random - import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.python.framework import constant_op +from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -INPUT_MATRIX = np.array( - [[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0], - [0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6], - [2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 
0.0], - [3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0], - [0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32) - -def np_matrix_to_tf_sparse(np_matrix, - row_slices=None, - col_slices=None, - transpose=False, - shuffle=False): - """Simple util to slice non-zero np matrix elements as tf.SparseTensor.""" - indices = np.nonzero(np_matrix) - - # Only allow slices of whole rows or whole columns. - assert not (row_slices is not None and col_slices is not None) - - if row_slices is not None: - selected_ind = np.concatenate( - [np.where(indices[0] == r)[0] for r in row_slices], 0) - indices = (indices[0][selected_ind], indices[1][selected_ind]) - - if col_slices is not None: - selected_ind = np.concatenate( - [np.where(indices[1] == c)[0] for c in col_slices], 0) - indices = (indices[0][selected_ind], indices[1][selected_ind]) - - if shuffle: - shuffled_ind = [x for x in range(len(indices[0]))] - random.shuffle(shuffled_ind) - indices = (indices[0][shuffled_ind], indices[1][shuffled_ind]) - - ind = (np.concatenate((np.expand_dims(indices[1], 1), - np.expand_dims(indices[0], 1)), 1).astype(np.int64) if - transpose else np.concatenate((np.expand_dims(indices[0], 1), - np.expand_dims(indices[1], 1)), - 1).astype(np.int64)) - val = np_matrix[indices].astype(np.float32) - shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64) - if transpose else np.array( - [max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64)) - return sparse_tensor.SparseTensor(ind, val, shape) - - -def sparse_input(): - return np_matrix_to_tf_sparse(INPUT_MATRIX) - - -def count_rows(sp_input): - return math_ops.cast( - array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0], - dtypes.float32) - - -def count_cols(sp_input): - return math_ops.cast( - array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0], - dtypes.float32) - - -def calculate_loss(input_mat, row_factors, col_factors, regularization=None, - w0=1., row_weights=None, col_weights=None): - """Calculates the loss of a given factorization. - - Using a non distributed method, different than the one implemented in the - WALS model. The weight of an observed entry (i, j) (i.e. such that - input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]). - - Args: - input_mat: The input matrix, a SparseTensor of rank 2. - row_factors: The row factors, a dense Tensor of rank 2. - col_factors: The col factors, a dense Tensor of rank 2. - regularization: the regularization coefficient, a scalar. - w0: the weight of unobserved entries. A scalar. - row_weights: A dense tensor of rank 1. - col_weights: A dense tensor of rank 1. - - Returns: - The total loss. 
- """ - wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None - else constant_op.constant(1.)) - wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None - else constant_op.constant(1.)) - reg = (regularization if regularization is not None - else constant_op.constant(0.)) - - row_indices, col_indices = array_ops.split(input_mat.indices, - axis=1, - num_or_size_splits=2) - gathered_row_factors = array_ops.gather(row_factors, row_indices) - gathered_col_factors = array_ops.gather(col_factors, col_indices) - sp_approx_vals = array_ops.squeeze(math_ops.matmul( - gathered_row_factors, gathered_col_factors, adjoint_b=True)) - sp_approx = sparse_tensor.SparseTensor( - indices=input_mat.indices, - values=sp_approx_vals, - dense_shape=input_mat.dense_shape) - - sp_approx_sq = math_ops.square(sp_approx) - row_norm = math_ops.reduce_sum(math_ops.square(row_factors)) - col_norm = math_ops.reduce_sum(math_ops.square(col_factors)) - row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul( - row_factors, col_factors, transpose_b=True))) - - resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1)) - resid_sq = math_ops.square(resid) - loss = w0 * ( - sparse_ops.sparse_reduce_sum(resid_sq) - - sparse_ops.sparse_reduce_sum(sp_approx_sq) - ) - loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) + - w0 * row_col_norm + reg * (row_norm + col_norm)) - return loss.eval() - - -def calculate_loss_from_wals_model(wals_model, sp_inputs): - current_rows = embedding_ops.embedding_lookup( - wals_model.row_factors, math_ops.range(wals_model._input_rows), - partition_strategy="div") - current_cols = embedding_ops.embedding_lookup( - wals_model.col_factors, math_ops.range(wals_model._input_cols), - partition_strategy="div") - row_wts = embedding_ops.embedding_lookup( - wals_model._row_weights, math_ops.range(wals_model._input_rows), - partition_strategy="div") - col_wts = embedding_ops.embedding_lookup( - wals_model._col_weights, math_ops.range(wals_model._input_cols), - partition_strategy="div") - return calculate_loss( - sp_inputs, current_rows, current_cols, wals_model._regularization, - wals_model._unobserved_weight, row_wts, col_wts) +INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX +np_matrix_to_tf_sparse = factorization_ops_test_utils.np_matrix_to_tf_sparse class WalsModelTest(test.TestCase): + def sparse_input(self): + return np_matrix_to_tf_sparse(INPUT_MATRIX) + + def count_rows(self, sp_input): + return math_ops.cast( + array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0], + dtypes.float32) + + def count_cols(self, sp_input): + return math_ops.cast( + array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0], + dtypes.float32) + + def calculate_loss_from_wals_model(self, wals_model, sp_inputs): + current_rows = embedding_ops.embedding_lookup( + wals_model.row_factors, math_ops.range(wals_model._input_rows), + partition_strategy="div") + current_cols = embedding_ops.embedding_lookup( + wals_model.col_factors, math_ops.range(wals_model._input_cols), + partition_strategy="div") + row_wts = embedding_ops.embedding_lookup( + wals_model._row_weights, math_ops.range(wals_model._input_rows), + partition_strategy="div") + col_wts = embedding_ops.embedding_lookup( + wals_model._col_weights, math_ops.range(wals_model._input_cols), + partition_strategy="div") + return factorization_ops_test_utils.calculate_loss( + sp_inputs, current_rows, current_cols, wals_model._regularization, + wals_model._unobserved_weight, row_wts, col_wts) + def 
setUp(self): self.col_init = [ # shard 0 @@ -208,7 +104,7 @@ class WalsModelTest(test.TestCase): use_factors_weights_cache, compute_loss=False): with ops.Graph().as_default(), self.test_session() as sess: - self._wals_inputs = sparse_input() + self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 num_cols = 7 @@ -282,10 +178,10 @@ class WalsModelTest(test.TestCase): if compute_loss: # Test loss computation after the row update loss = sum( - sess.run(factor_loss * count_rows(inp) / num_rows, + sess.run(factor_loss * self.count_rows(inp) / num_rows, feed_dict={sp_feeder: inp}) for inp in input_scattered_rows) - true_loss = calculate_loss_from_wals_model( + true_loss = self.calculate_loss_from_wals_model( wals_model, self._wals_inputs) self.assertNear( loss, true_loss, err=.001, @@ -355,10 +251,10 @@ class WalsModelTest(test.TestCase): if compute_loss: # Test loss computation after the column update. loss = sum( - sess.run(factor_loss * count_cols(inp) / num_cols, + sess.run(factor_loss * self.count_cols(inp) / num_cols, feed_dict={sp_feeder: inp}) for inp in input_scattered_cols_non_duplicate) - true_loss = calculate_loss_from_wals_model( + true_loss = self.calculate_loss_from_wals_model( wals_model, self._wals_inputs) self.assertNear( loss, true_loss, err=.001, @@ -368,7 +264,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input_transposed(self, use_factors_weights_cache, compute_loss=False): with ops.Graph().as_default(), self.test_session() as sess: - self._wals_inputs = sparse_input() + self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 num_cols = 7 @@ -448,10 +344,10 @@ class WalsModelTest(test.TestCase): if compute_loss: # Test loss computation after the row update loss = sum( - sess.run(factor_loss * count_cols(inp) / num_rows, + sess.run(factor_loss * self.count_cols(inp) / num_rows, feed_dict={sp_feeder: inp}) for inp in input_scattered_rows_non_duplicate) - true_loss = calculate_loss_from_wals_model( + true_loss = self.calculate_loss_from_wals_model( wals_model, self._wals_inputs) self.assertNear( loss, true_loss, err=.001, @@ -516,10 +412,10 @@ class WalsModelTest(test.TestCase): if compute_loss: # Test loss computation after the col update loss = sum( - sess.run(factor_loss * count_rows(inp) / num_cols, + sess.run(factor_loss * self.count_rows(inp) / num_cols, feed_dict={sp_feeder: inp}) for inp in input_scattered_cols_non_duplicate) - true_loss = calculate_loss_from_wals_model( + true_loss = self.calculate_loss_from_wals_model( wals_model, self._wals_inputs) self.assertNear( loss, true_loss, err=.001, @@ -534,7 +430,7 @@ class WalsModelTest(test.TestCase): # Here we test that those two give identical results. 
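For reference, the weighted loss evaluated by `calculate_loss` (moved into `factorization_ops_test_utils.py` later in this patch) can be written densely in NumPy. A hedged sketch assuming a small dense matrix whose zeros mark unobserved entries (all names here are local to the example):

# Dense restatement of the WALS test loss:
#   sum over observed (i, j) of (w0 + wr[i] * wc[j]) * (A[i, j] - U[i] . V[j])**2
#   + w0 * sum over unobserved (i, j) of (U[i] . V[j])**2
#   + reg * (||U||_F**2 + ||V||_F**2)
import numpy as np

rng = np.random.RandomState(0)
A = np.array([[1.0, 0.0], [0.0, 2.0]])  # zeros are unobserved entries
U, V = rng.rand(2, 3), rng.rand(2, 3)
w0, reg = 1.0, 0.01
wr, wc = np.array([0.5, 1.5]), np.array([2.0, 1.0])

approx = U.dot(V.T)
observed = A != 0.0
weights = w0 + np.outer(wr, wc)
loss = (np.sum(weights[observed] * (A - approx)[observed] ** 2)
        + w0 * np.sum(approx[~observed] ** 2)
        + reg * (np.sum(U ** 2) + np.sum(V ** 2)))
print(loss)  # the sparse TensorFlow computation should agree with this value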
def _run_test_als(self, use_factors_weights_cache): with ops.Graph().as_default(), self.test_session(): - self._wals_inputs = sparse_input() + self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( 5, @@ -613,7 +509,7 @@ class WalsModelTest(test.TestCase): def _run_test_als_transposed(self, use_factors_weights_cache): with ops.Graph().as_default(), self.test_session(): - self._wals_inputs = sparse_input() + self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( 5, diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test_utils.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test_utils.py new file mode 100644 index 00000000000..69572119732 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test_utils.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utils for factorization_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops + + +INPUT_MATRIX = np.array( + [[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0], + [0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6], + [2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0], + [3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0], + [0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32) + + +def np_matrix_to_tf_sparse(np_matrix, + row_slices=None, + col_slices=None, + transpose=False, + shuffle=False): + """Simple util to slice non-zero np matrix elements as tf.SparseTensor.""" + indices = np.nonzero(np_matrix) + + # Only allow slices of whole rows or whole columns. 
+  assert not (row_slices is not None and col_slices is not None)
+
+  if row_slices is not None:
+    selected_ind = np.concatenate(
+        [np.where(indices[0] == r)[0] for r in row_slices], 0)
+    indices = (indices[0][selected_ind], indices[1][selected_ind])
+
+  if col_slices is not None:
+    selected_ind = np.concatenate(
+        [np.where(indices[1] == c)[0] for c in col_slices], 0)
+    indices = (indices[0][selected_ind], indices[1][selected_ind])
+
+  if shuffle:
+    shuffled_ind = [x for x in range(len(indices[0]))]
+    random.shuffle(shuffled_ind)
+    indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
+
+  ind = (np.concatenate((np.expand_dims(indices[1], 1),
+                         np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
+         transpose else np.concatenate((np.expand_dims(indices[0], 1),
+                                        np.expand_dims(indices[1], 1)),
+                                       1).astype(np.int64))
+  val = np_matrix[indices].astype(np.float32)
+  shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
+           if transpose else np.array(
+               [max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
+  return sparse_tensor.SparseTensor(ind, val, shape)
+
+
+def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
+                   w0=1., row_weights=None, col_weights=None):
+  """Calculates the loss of a given factorization.
+
+  Uses a non-distributed method, different from the one implemented in the
+  WALS model. The weight of an observed entry (i, j) (i.e. such that
+  input_mat[i, j] is non-zero) is (w0 + row_weights[i] * col_weights[j]).
+
+  Args:
+    input_mat: The input matrix, a SparseTensor of rank 2.
+    row_factors: The row factors, a dense Tensor of rank 2.
+    col_factors: The col factors, a dense Tensor of rank 2.
+    regularization: The regularization coefficient, a scalar.
+    w0: The weight of unobserved entries. A scalar.
+    row_weights: A dense tensor of rank 1.
+    col_weights: A dense tensor of rank 1.
+
+  Returns:
+    The total loss.
+ """ + wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None + else constant_op.constant(1.)) + wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None + else constant_op.constant(1.)) + reg = (regularization if regularization is not None + else constant_op.constant(0.)) + + row_indices, col_indices = array_ops.split(input_mat.indices, + axis=1, + num_or_size_splits=2) + gathered_row_factors = array_ops.gather(row_factors, row_indices) + gathered_col_factors = array_ops.gather(col_factors, col_indices) + sp_approx_vals = array_ops.squeeze(math_ops.matmul( + gathered_row_factors, gathered_col_factors, adjoint_b=True)) + sp_approx = sparse_tensor.SparseTensor( + indices=input_mat.indices, + values=sp_approx_vals, + dense_shape=input_mat.dense_shape) + + sp_approx_sq = math_ops.square(sp_approx) + row_norm = math_ops.reduce_sum(math_ops.square(row_factors)) + col_norm = math_ops.reduce_sum(math_ops.square(col_factors)) + row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul( + row_factors, col_factors, transpose_b=True))) + + resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1)) + resid_sq = math_ops.square(resid) + loss = w0 * ( + sparse_ops.sparse_reduce_sum(resid_sq) - + sparse_ops.sparse_reduce_sum(sp_approx_sq) + ) + loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) + + w0 * row_col_norm + reg * (row_norm + col_norm)) + return loss.eval() diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index b37cbc119f4..fccaa3abd4d 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -21,9 +21,14 @@ py_library( ":dense_kernel_mapper_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", - "//tensorflow/python:framework", - "//tensorflow/python:ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -31,6 +36,7 @@ py_library( name = "dense_kernel_mapper_py", srcs = ["python/mappers/dense_kernel_mapper.py"], srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) py_test( @@ -40,12 +46,12 @@ py_test( deps = [ ":dense_kernel_mapper_py", ":kernel_methods", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:nn", - "//tensorflow/python:ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", ], ) @@ -55,10 +61,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":kernel_methods", - "//tensorflow/python:client_testlib", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_tensor", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kernel_methods/__init__.py b/tensorflow/contrib/kernel_methods/__init__.py index 1a3a0ab77a6..a715df1bf12 100644 --- a/tensorflow/contrib/kernel_methods/__init__.py +++ b/tensorflow/contrib/kernel_methods/__init__.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.kernel_methods.python.kernel_estimators import 
KernelLinearClassifier -from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index f5eb723e2db..fc33e4b49e0 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -118,6 +118,7 @@ tf_custom_op_py_library( "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:clip_ops", + "//tensorflow/python:common_shapes", "//tensorflow/python:control_flow_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", @@ -131,9 +132,11 @@ tf_custom_op_py_library( "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:standard_ops", "//tensorflow/python:string_ops", "//tensorflow/python:summary", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 7f1bfc9605b..2662d0ae3e7 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables @@ -36,6 +38,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest def _embeddings_from_arguments(column, @@ -136,6 +139,58 @@ def _embeddings_from_arguments(column, max_norm=args.max_norm) +def _maybe_reshape_input_tensor(tensor, column_name, output_rank): + """Reshape the input tensor by the following rule. + + 1. If `output_rank > input_rank + 1`, raise a `ValueError`. + 2. If `output_rank == input_rank + 1`, expand the tensor by one dimension. + 3. If `output_rank == input_rank`, do nothing. + 4. If `output_rank < input_rank`, flatten the inner dimensions of the tensor. + + Args: + tensor: A Tensor or SparseTensor to be reshaped. + column_name: A string name of the feature column for the tensor. + output_rank: the desired rank of the tensor. + Returns: + A reshaped Tensor or SparseTensor. + Raises: + ValueError: if `output_rank > input_rank + 1` for the input tensor. + """ + input_rank = tensor.get_shape().ndims + + if input_rank is None and isinstance(tensor, sparse_tensor_py.SparseTensor): + # Try to get the rank of a sparse tensor by its dense_shape's shape. + input_rank = tensor.dense_shape.get_shape().as_list()[0] + + if input_rank is None: + raise ValueError('Error while processing column {}. Rank of input Tensor ' + 'can not be None.'.format(column_name)) + + if output_rank > input_rank + 1: + raise ValueError('Error while processing column {}. Rank of input Tensor ' + '({}) should be the same as output_rank ({}). 
For ' + 'example, sequence data should typically be 3 ' + 'dimensional (rank 3) while non-sequence data is ' + 'typically 2 dimensional (rank 2).'.format( + column_name, input_rank, output_rank)) + elif output_rank == input_rank + 1: + # Expand the tensor's shape by 1 dimension. + if isinstance(tensor, sparse_tensor_py.SparseTensor): + output_shape = array_ops.concat([tensor.dense_shape, [1]], 0) + return sparse_ops.sparse_reshape(tensor, output_shape) + else: + reshaped = array_ops.expand_dims(tensor, -1) + # Try to calculate the new shape. + static_shape = tensor.get_shape() + if static_shape is not None and static_shape.dims is not None: + reshaped.set_shape(static_shape.as_list() + [1]) + return reshaped + elif output_rank < input_rank: + return layers._inner_flatten(tensor, output_rank) # pylint: disable=protected-access + else: + return tensor + + def _input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections, @@ -160,6 +215,12 @@ def _input_from_feature_columns(columns_to_tensors, default_name=column.name, values=columns_to_tensors.values()): transformed_tensor = transformer.transform(column) + if output_rank == 3: + transformed_tensor = nest.map_structure( + functools.partial( + _maybe_reshape_input_tensor, + column_name=column.name, + output_rank=output_rank), transformed_tensor) try: # pylint: disable=protected-access arguments = column._deep_embedding_lookup_arguments( @@ -548,7 +609,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors, default_name=column.name, values=columns_to_tensors.values()): tensor = column._to_dense_tensor(transformed_tensor) - tensor = fc._reshape_real_valued_tensor(tensor, 2, column.name) + tensor = _maybe_reshape_input_tensor( + tensor, column.name, output_rank=2) variable = [ contrib_variables.model_variable( name='weight', diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 35daef9cc6e..01c54f77d62 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -1350,6 +1350,35 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): self.assertAllEqual(expected_input_shape, model_input.shape) + def testEmbeddingColumnWithAutoReshape(self): + hash_buckets = 10 + embedding_dimension = 5 + ids_tensor = sparse_tensor.SparseTensor( + values=["c", "b", + "a", "c", "b", + "b"], + indices=[[0, 0], [0, 1], + [1, 0], [1, 1], [1, 2], + [3, 2]], + dense_shape=[4, 3]) + + expected_input_shape = np.array([4, 3, embedding_dimension]) + + hashed_ids_column = feature_column.sparse_column_with_hash_bucket( + "ids", hash_buckets) + embedded_column = feature_column.embedding_column(hashed_ids_column, + embedding_dimension) + columns_to_tensors = {"ids": ids_tensor} + model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( + columns_to_tensors, [embedded_column]) + + with self.test_session() as sess: + variables_lib.global_variables_initializer().run() + data_flow_ops.tables_initializer().run() + model_input = sess.run(model_input_tensor) + + self.assertAllEqual(expected_input_shape, model_input.shape) + def testEmbeddingColumnGradient(self): hash_buckets = 1000 embedding_dimension = 3 diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 90d4d7e3a8f..148f2708c85 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -836,6 +836,19 @@ py_test( ], ) +py_test( + 
name = "model_fn_test", + size = "small", + srcs = ["python/learn/estimators/model_fn_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//third_party/py/numpy", + ], +) + py_test( name = "multioutput_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 028a13ca20a..14f7666b3df 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.summary import summary @@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -885,6 +887,8 @@ class _BinaryLogisticHead(_SingleHead): _indicator_labels_streaming_mean(labels, weights)) metrics[_summary_key(self.head_name, mkey.AUC)] = ( _streaming_auc(logistic, labels, weights)) + metrics[_summary_key(self.head_name, mkey.AUC_PR)] = ( + _streaming_auc(logistic, labels, weights, curve="PR")) for threshold in self._thresholds: metrics[_summary_key( @@ -1009,7 +1013,8 @@ class _MultiClassHead(_SingleHead): loss_fn=self._wrapped_loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type, self._label_keys), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -1113,25 +1118,6 @@ class _MultiClassHead(_SingleHead): return metrics - def _create_output_alternatives(self, predictions): - """See superclass.""" - probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] - batch_size = array_ops.shape(probabilities)[0] - if self._label_keys: - classes = array_ops.tile( - input=array_ops.expand_dims(input=self._label_keys, axis=0), - multiples=[batch_size, 1]) - else: - classes = array_ops.tile( - input=array_ops.expand_dims( - input=math_ops.range(self.logits_dimension), axis=0), - multiples=[batch_size, 1]) - predictions_for_serving = { - prediction_key.PredictionKey.CLASSES: classes, - prediction_key.PredictionKey.PROBABILITIES: probabilities, - } - return {self._head_name: (self._problem_type, predictions_for_serving)} - def _to_labels_tensor(labels, label_name): """Returns label as a tensor. @@ -1226,6 +1212,7 @@ class _BinarySvmHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, + # TODO(zakaria): Handle labels for export. 
create_output_alternatives_fn=self._create_output_alternatives, labels=labels, train_op_fn=train_op_fn, @@ -1325,7 +1312,8 @@ class _MultiLabelHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -1374,6 +1362,8 @@ class _MultiLabelHead(_SingleHead): metrics_lib.streaming_accuracy(classes, labels, weights)) metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc( probabilities, labels, weights) + metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc( + probabilities, labels, weights, curve="PR") for class_id in self._metric_class_ids: # TODO(ptucker): Add per-class accuracy, precision, recall. @@ -1391,6 +1381,9 @@ class _MultiLabelHead(_SingleHead): _predictions_streaming_mean(logits, weights, class_id)) metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = ( _streaming_auc(probabilities, labels, weights, class_id)) + metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = ( + _streaming_auc(probabilities, labels, weights, class_id, + curve="PR")) return metrics @@ -1857,7 +1850,8 @@ def _class_labels_streaming_mean(labels, weights, class_id): weights=weights) -def _streaming_auc(predictions, labels, weights=None, class_id=None): +def _streaming_auc(predictions, labels, weights=None, class_id=None, + curve="ROC"): predictions = ops.convert_to_tensor(predictions) labels = ops.convert_to_tensor(labels) if class_id is not None: @@ -1866,7 +1860,8 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None): return metrics_lib.streaming_auc( predictions, math_ops.cast(labels, dtypes.bool), - weights=_float_weights_or_none(weights)) + weights=_float_weights_or_none(weights), + curve=curve) def _assert_class_id(class_id, num_classes=None): @@ -1901,6 +1896,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold): return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) +def _classification_output_alternatives(head_name, problem_type, + label_keys=None): + """Creates a func to generate output alternatives for classification. + + Servo expects classes to be a string tensor, and have the same dimensions + as the probabilities tensor. It should contain the labels of the corresponding + entries in probabilities. This function creates a new classes tensor that + satisfies these conditions and can be exported. + + Args: + head_name: Name of the head. + problem_type: `ProblemType` + label_keys: Optional label keys + + Returns: + A function to generate output alternatives. + """ + def _create_output_alternatives(predictions): + """Creates output alternative for the Head. + + Args: + predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a + symbolic name for an output Tensor possibly but not necessarily taken + from `PredictionKey`, and 'Tensor' is the corresponding output Tensor + itself. 
+ + Returns: + `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where + 'submodel_name' is a submodel identifier that should be consistent across + the pipeline (here likely taken from the head_name), + 'problem_type' is a `ProblemType`, + 'tensor_name' is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and + 'Tensor' is the corresponding output Tensor itself. + + Raises: + ValueError: if predictions does not have PredictionKey.PROBABILITIES key. + """ + probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES) + if probabilities is None: + raise ValueError("%s missing in predictions" % + prediction_key.PredictionKey.PROBABILITIES) + + with ops.name_scope(None, "_classification_output_alternatives", + (probabilities,)): + batch_size = array_ops.shape(probabilities)[0] + if label_keys: + classes = array_ops.tile( + input=array_ops.expand_dims(input=label_keys, axis=0), + multiples=[batch_size, 1], + name="classes_tensor") + else: + n = array_ops.shape(probabilities)[1] + classes = array_ops.tile( + input=array_ops.expand_dims(input=math_ops.range(n), axis=0), + multiples=[batch_size, 1]) + classes = string_ops.as_string(classes, name="classes_tensor") + + exported_predictions = { + prediction_key.PredictionKey.PROBABILITIES: probabilities, + prediction_key.PredictionKey.CLASSES: classes} + return {head_name: (problem_type, exported_predictions)} + + return _create_output_alternatives + # Aliases # TODO(zakaria): Remove these aliases, See b/34751732 _regression_head = regression_head diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index ecc1d9ff9e1..9b8cba15263 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -297,11 +297,15 @@ class MultiLabelHeadTest(test.TestCase): def _expected_eval_metrics(self, expected_loss): return { "accuracy": 1. / 3, - "auc": 1. / 4, "loss": expected_loss, + "auc": 1. 
/ 4, "auc/class0": 1., "auc/class1": 1., "auc/class2": 0., + "auc_precision_recall": 0.166667, + "auc_precision_recall/class0": 0, + "auc_precision_recall/class1": 0., + "auc_precision_recall/class2": 1., "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -417,7 +421,7 @@ class MultiLabelHeadTest(test.TestCase): {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, logits_input=((0., 0.),), logits=self._logits) - def testMultiLabelEvalMode(self): + def testMultiLabelEval(self): n_classes = 3 head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -433,7 +437,7 @@ class MultiLabelHeadTest(test.TestCase): _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) - def testMultiClassEvalModeWithLargeLogits(self): + def testMultiClassEvalWithLargeLogits(self): n_classes = 3 head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -472,6 +476,36 @@ class MultiLabelHeadTest(test.TestCase): _assert_metrics(self, expected_loss, expected_eval_metrics, model_fn_ops) + def testMultiLabelInfer(self): + n_classes = 3 + head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name") + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, + logits=((1., 0., 0.), (0., 0., 1))) + self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + with session.Session(): + self.assertListEqual( + [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0]) + self.assertItemsEqual( + ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertEqual( + constants.ProblemType.CLASSIFICATION, + model_fn_ops.output_alternatives["head_name"][0]) + + predictions_for_serving = ( + model_fn_ops.output_alternatives["head_name"][1]) + self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertAllEqual( + [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], + predictions_for_serving["classes"].eval()) + self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertAllClose( + [[0.731059, 0.5, 0.5], + [0.5, 0.5, 0.731059,]], + predictions_for_serving["probabilities"].eval()) + def testMultiLabelWithLabelName(self): n_classes = 3 label_name = "my_label" @@ -621,6 +655,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. 
/ 2, + "auc_precision_recall": 0.749999, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, @@ -691,7 +726,7 @@ class BinaryClassificationHeadTest(test.TestCase): {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, logits_input=((0., 0.), (0., 0.)), logits=self._logits) - def testBinaryClassificationEvalMode(self): + def testBinaryClassificationEval(self): n_classes = 2 head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): @@ -708,18 +743,32 @@ class BinaryClassificationHeadTest(test.TestCase): _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) - def testBinaryClassificationInferMode(self): + def testBinaryClassificationInfer(self): n_classes = 2 - head = head_lib.multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name") with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.create_model_fn_ops( {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, logits=self._logits) - self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) + with session.Session(): + self.assertListEqual( + [1, 1], list(model_fn_ops.predictions["classes"].eval())) + self.assertItemsEqual( + ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertEqual( + constants.ProblemType.LOGISTIC_REGRESSION, + model_fn_ops.output_alternatives["head_name"][0]) + predictions_for_serving = ( + model_fn_ops.output_alternatives["head_name"][1]) + self.assertIn("classes", six.iterkeys(predictions_for_serving)) + predicted_classes = predictions_for_serving["classes"].eval().tolist() + self.assertListEqual( + [b"0", b"1"], predicted_classes[0]) + self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) def testBinaryClassificationInferMode_withWightColumn(self): n_classes = 2 @@ -1006,7 +1055,7 @@ class MultiClassHeadTest(test.TestCase): "multi_class_head/centered_bias/bias_1", "multi_class_head/centered_bias/bias_2"]) - def testMultiClassEvalMode(self): + def testMultiClassEval(self): n_classes = 3 head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -1131,7 +1180,7 @@ class MultiClassHeadTest(test.TestCase): model_fn_ops.output_alternatives["head_name"][1]) self.assertIn("classes", six.iterkeys(predictions_for_serving)) self.assertAllEqual( - [[0, 1, 2], [0, 1, 2]], + [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) self.assertAllClose( diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 10ac888eca7..99388f116b3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -22,7 +22,9 @@ class MetricKey(object): """Metric key strings.""" LOSS = "loss" AUC = "auc" + AUC_PR = "auc_precision_recall" CLASS_AUC = "auc/class%d" + CLASS_AUC_PR = "auc_precision_recall/class%d" PREDICTION_MEAN = "labels/prediction_mean" CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d" CLASS_LOGITS_MEAN = "labels/logits_mean/class%d" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py 
b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index 3c812a46597..6d15f83ef53 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -25,10 +25,16 @@ import six
 
 from tensorflow.contrib import framework as contrib_framework
 from tensorflow.contrib.framework import get_graph_from_inputs
-
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.python.estimator import model_fn as core_model_fn_lib
+from tensorflow.python.estimator.export import export_output as core_export_lib
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.training import session_run_hook
 
@@ -177,3 +183,85 @@ class ModelFnOps(
         training_chief_hooks=training_chief_hooks,
         training_hooks=training_hooks,
         scaffold=scaffold)
+
+  def estimator_spec(self, mode, default_serving_output_alternative_key=None):
+    """Creates an equivalent `EstimatorSpec`.
+
+    Args:
+      mode: One of `ModeKeys`. Specifies if this is training, evaluation, or
+        prediction.
+      default_serving_output_alternative_key: Required for multiple heads. If
+        you have multiple entries in `output_alternatives` dict (comparable to
+        multiple heads), `EstimatorSpec` requires a default head that will be
+        used if a Servo request does not explicitly mention which head to infer
+        on. Pass the key of the output alternative here that you want to
+        designate as default. A separate ExportOutput for this default head
+        will be added to the export_outputs dict with the special key
+        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
+        already an entry in output_alternatives with this special key.
+
+    Returns:
+      Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`.
+
+    Raises:
+      ValueError: If problem type is unknown.
+ """ + def _scores(output_tensors): + scores = output_tensors.get(prediction_key.PredictionKey.SCORES) + if scores is None: + scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES) + return scores + + def _classes(output_tensors): # pylint: disable=missing-docstring + classes = output_tensors.get(prediction_key.PredictionKey.CLASSES) + if classes is None: + logging.warning( + 'classes is None, Servo inference will not have class ids.') + return None + elif classes.dtype != dtypes.string: + # Servo classification can only serve string classes + logging.warning( + 'classes is not string, Servo inference will not have class ids.') + return None + + return classes + + def _export_output(problem_type, predictions): # pylint: disable=missing-docstring + if problem_type == constants.ProblemType.LINEAR_REGRESSION: + return core_export_lib.RegressionOutput(_scores(predictions)) + + if (problem_type == constants.ProblemType.CLASSIFICATION or + problem_type == constants.ProblemType.LOGISTIC_REGRESSION): + return core_export_lib.ClassificationOutput( + scores=_scores(predictions), classes=_classes(predictions)) + + if problem_type == constants.ProblemType.UNSPECIFIED: + return core_export_lib.PredictOutput(predictions) + + raise ValueError('Unknown problem_type=%s' % problem_type) + + # Converts output_alternatives + export_outputs_dict = None + if self.output_alternatives: + output_alternatives = self.output_alternatives + # Adds default output_alternative if needed. + if (len(output_alternatives) > 1 and + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in + output_alternatives): + output_alternatives = output_alternatives.copy() + output_alternatives[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( + output_alternatives[default_serving_output_alternative_key]) + export_outputs_dict = {key: _export_output(*val) for key, val in + output_alternatives.items()} + + return core_model_fn_lib.EstimatorSpec( + mode=mode, + predictions=self.predictions, + loss=self.loss, + train_op=self.train_op, + eval_metric_ops=self.eval_metric_ops, + export_outputs=export_outputs_dict, + training_chief_hooks=self.training_chief_hooks, + training_hooks=self.training_hooks, + scaffold=self.scaffold) diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py new file mode 100644 index 00000000000..fe8b3a1b346 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py @@ -0,0 +1,279 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
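Before the new unit tests, a minimal usage sketch of the conversion above, as it might be driven from a core `tf.estimator` `model_fn`; the head name and the `build_model_fn_ops` helper are hypothetical, not part of this change:

```python
def my_core_model_fn(features, labels, mode):
  # build_model_fn_ops is an assumed helper that returns a contrib
  # ModelFnOps whose output_alternatives dict has two heads.
  model_fn_ops = build_model_fn_ops(features, labels, mode)
  # With more than one output alternative, name a default head; it is
  # re-exported under signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
  return model_fn_ops.estimator_spec(
      mode, default_serving_output_alternative_key="regression_head")
```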
+# ============================================================================== +"""ModelFnOps tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.client import session +from tensorflow.python.estimator.export import export_output as core_export_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import monitored_session + + +class ModelFnopsTest(test.TestCase): + """Tests for `ModelFnOps.estimator_spec`.""" + + def create_predictions(self): + probabilities = constant_op.constant([1., 1., 1.]) + scores = constant_op.constant([1., 2., 3.]) + classes = constant_op.constant([b"0", b"1", b"2"]) + return { + "probabilities": probabilities, + "scores": scores, + "classes": classes} + + def create_model_fn_ops(self, predictions, output_alternatives, + mode=model_fn.ModeKeys.INFER): + + return model_fn.ModelFnOps( + mode, + predictions=predictions, + loss=constant_op.constant([1]), + train_op=control_flow_ops.no_op(), + eval_metric_ops={"metric_key": (control_flow_ops.no_op(), + control_flow_ops.no_op())}, + training_chief_hooks=[basic_session_run_hooks.StepCounterHook()], + training_hooks=[basic_session_run_hooks.StepCounterHook()], + output_alternatives=output_alternatives, + scaffold=monitored_session.Scaffold()) + + def assertEquals_except_export(self, model_fn_ops, estimator_spec): + self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions) + self.assertEqual(model_fn_ops.loss, estimator_spec.loss) + self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op) + self.assertEqual(model_fn_ops.eval_metric_ops, + estimator_spec.eval_metric_ops) + self.assertEqual(model_fn_ops.training_chief_hooks, + estimator_spec.training_chief_hooks) + self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks) + self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold) + + def testEstimatorSpec_except_export(self): + predictions = self.create_predictions() + model_fn_ops = self.create_model_fn_ops(predictions, None) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + def testEstimatorSpec_export_regression_with_scores(self): + predictions = self.create_predictions() + output_alternatives = {"regression_head": ( + constants.ProblemType.LINEAR_REGRESSION, predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + regression_output = estimator_spec.export_outputs["regression_head"] + self.assertTrue(isinstance( + regression_output, core_export_lib.RegressionOutput)) + self.assertAllEqual(predictions["scores"].eval(), + regression_output.value.eval()) + + def testEstimatorSpec_export_regression_with_probabilities(self): + predictions = self.create_predictions() + output_alternatives_predictions = predictions.copy() + del output_alternatives_predictions["scores"] + output_alternatives 
= {"regression_head": ( + constants.ProblemType.LINEAR_REGRESSION, + output_alternatives_predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + regression_output = estimator_spec.export_outputs["regression_head"] + self.assertTrue(isinstance( + regression_output, core_export_lib.RegressionOutput)) + self.assertAllEqual(predictions["probabilities"].eval(), + regression_output.value.eval()) + + def testEstimatorSpec_export_classification(self): + predictions = self.create_predictions() + output_alternatives = {"classification_head": ( + constants.ProblemType.CLASSIFICATION, predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + classification_output = estimator_spec.export_outputs[ + "classification_head"] + self.assertTrue(isinstance(classification_output, + core_export_lib.ClassificationOutput)) + self.assertAllEqual(predictions["scores"].eval(), + classification_output.scores.eval()) + self.assertAllEqual(predictions["classes"].eval(), + classification_output.classes.eval()) + + def testEstimatorSpec_export_classification_with_missing_scores(self): + predictions = self.create_predictions() + output_alternatives_predictions = predictions.copy() + del output_alternatives_predictions["scores"] + output_alternatives = {"classification_head": ( + constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + classification_output = estimator_spec.export_outputs[ + "classification_head"] + self.assertTrue(isinstance(classification_output, + core_export_lib.ClassificationOutput)) + self.assertAllEqual(predictions["probabilities"].eval(), + classification_output.scores.eval()) + self.assertAllEqual(predictions["classes"].eval(), + classification_output.classes.eval()) + + def testEstimatorSpec_export_classification_with_missing_scores_proba(self): + predictions = self.create_predictions() + output_alternatives_predictions = predictions.copy() + del output_alternatives_predictions["scores"] + del output_alternatives_predictions["probabilities"] + output_alternatives = {"classification_head": ( + constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + classification_output = estimator_spec.export_outputs[ + "classification_head"] + self.assertTrue(isinstance(classification_output, + core_export_lib.ClassificationOutput)) + self.assertIsNone(classification_output.scores) + self.assertAllEqual(predictions["classes"].eval(), + classification_output.classes.eval()) + + def testEstimatorSpec_export_classification_with_missing_classes(self): + predictions = self.create_predictions() + output_alternatives_predictions = predictions.copy() + del output_alternatives_predictions["classes"] +
output_alternatives = {"classification_head": ( + constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + classification_output = estimator_spec.export_outputs[ + "classification_head"] + self.assertTrue(isinstance(classification_output, + core_export_lib.ClassificationOutput)) + self.assertAllEqual(predictions["scores"].eval(), + classification_output.scores.eval()) + self.assertIsNone(classification_output.classes) + + def testEstimatorSpec_export_classification_with_nonstring_classes(self): + predictions = self.create_predictions() + output_alternatives_predictions = predictions.copy() + output_alternatives_predictions["classes"] = constant_op.constant( + [1, 2, 3]) + output_alternatives = {"classification_head": ( + constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + classification_output = estimator_spec.export_outputs[ + "classification_head"] + self.assertTrue(isinstance(classification_output, + core_export_lib.ClassificationOutput)) + self.assertAllEqual(predictions["scores"].eval(), + classification_output.scores.eval()) + self.assertIsNone(classification_output.classes) + + def testEstimatorSpec_export_logistic(self): + predictions = self.create_predictions() + output_alternatives = {"logistic_head": ( + constants.ProblemType.LOGISTIC_REGRESSION, predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + logistic_output = estimator_spec.export_outputs["logistic_head"] + self.assertTrue(isinstance(logistic_output, + core_export_lib.ClassificationOutput)) + self.assertAllEqual(predictions["scores"].eval(), + logistic_output.scores.eval()) + self.assertAllEqual(predictions["classes"].eval(), + logistic_output.classes.eval()) + + def testEstimatorSpec_export_unspecified(self): + predictions = self.create_predictions() + output_alternatives = {"unspecified_head": ( + constants.ProblemType.UNSPECIFIED, predictions)} + + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + self.assertEquals_except_export(model_fn_ops, estimator_spec) + + with session.Session(): + unspecified_output = estimator_spec.export_outputs["unspecified_head"] + self.assertTrue(isinstance(unspecified_output, + core_export_lib.PredictOutput)) + self.assertEqual(predictions, unspecified_output.outputs) + + def testEstimatorSpec_export_multihead(self): + predictions = self.create_predictions() + output_alternatives = { + "regression_head": ( + constants.ProblemType.LINEAR_REGRESSION, predictions), + "classification_head": ( + constants.ProblemType.CLASSIFICATION, predictions)} + model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + + estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER, + "regression_head") + self.assertEquals_except_export(model_fn_ops, estimator_spec) + +
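The `missing_classes` and `nonstring_classes` cases above hinge on one serving constraint: `ClassificationOutput` accepts only string class tensors, so `estimator_spec` drops anything else with a logged warning instead of failing the export. A small illustration with arbitrary values:

```python
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import constant_op

scores = constant_op.constant([0.1, 0.9])
classes = constant_op.constant([b"cat", b"dog"])
# Valid: serving classes must be a string Tensor.
output = core_export_lib.ClassificationOutput(scores=scores, classes=classes)

# By contrast, ClassificationOutput(classes=constant_op.constant([0, 1]))
# would raise a ValueError, which is what the fallback to classes=None avoids.
```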
with session.Session(): + regression_output = estimator_spec.export_outputs["regression_head"] + self.assertTrue(isinstance( + regression_output, core_export_lib.RegressionOutput)) + self.assertAllEqual(predictions["scores"].eval(), + regression_output.value.eval()) + + default_output = estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertTrue(isinstance(default_output, + core_export_lib.RegressionOutput)) + self.assertAllEqual(predictions["scores"].eval(), + default_output.value.eval()) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index f20dc788349..6bb2b8b2aad 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -66,8 +66,8 @@ def _get_single_cell(cell_type, num_units): ValueError: `cell_type` is an invalid `RNNCell` name. TypeError: `cell_type` is not a string or a subclass of `RNNCell`. """ - cell_type = _CELL_TYPES.get(cell_type) - if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell): + cell_type = _CELL_TYPES.get(cell_type, cell_type) + if not cell_type or not issubclass(cell_type, contrib_rnn.RNNCell): raise ValueError('The supported cell types are {}; got {}'.format( list(_CELL_TYPES.keys()), cell_type)) return cell_type(num_units=num_units) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index a8f8d995fe6..cecc24c17d8 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -97,7 +97,8 @@ class Experiment(object): finite number of batches (generally, 1 epoch over the evaluation data). eval_metrics: `dict` of string, metric function. If `None`, default set is used. This should be `None` if the `estimator` is - ${tf.estimator.Estimator}. + ${tf.estimator.Estimator}. If metrics are provided they will be + *appended* to the default set. train_steps: Perform this many steps of training. `None`, the default, means train forever. eval_steps: `evaluate` runs until input is exhausted (or another exception diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py index 93cbb48e1b2..d4a9e97ce7a 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py @@ -45,7 +45,7 @@ class SquareLinearOperatorFullMatrixTest( # values are random and we want the same value used for both mat and # feed_dict. matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix) + operator = linalg.LinearOperatorFullMatrix(matrix_ph) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix(matrix) @@ -105,7 +105,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( # feed_dict. matrix = matrix.eval() operator = linalg.LinearOperatorFullMatrix( - matrix, is_self_adjoint=True, is_positive_definite=True) + matrix_ph, is_self_adjoint=True, is_positive_definite=True) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix( @@ -144,7 +144,7 @@ class NonSquareLinearOperatorFullMatrixTest( # values are random and we want the same value used for both mat and # feed_dict. 
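Looking back at the `rnn_common._get_single_cell` fix above: `_CELL_TYPES.get(cell_type, cell_type)` now passes an `RNNCell` subclass through unchanged instead of mapping it to `None`, so both call styles below work. A sketch only; the function is private, and the `'lstm'` registry key is an assumption about `_CELL_TYPES`:

```python
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl

# By registry key (assumed key name):
cell = rnn_common._get_single_cell('lstm', num_units=32)

# By RNNCell subclass, accepted thanks to the .get(cell_type, cell_type)
# fallback:
cell = rnn_common._get_single_cell(core_rnn_cell_impl.BasicRNNCell,
                                   num_units=32)
```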
matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix) + operator = linalg.LinearOperatorFullMatrix(matrix_ph) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix(matrix) diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py index 0275ed60798..d851c522c03 100644 --- a/tensorflow/contrib/nccl/__init__.py +++ b/tensorflow/contrib/nccl/__init__.py @@ -12,13 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ops for nccl AllReduce.""" +"""Functions for using NVIDIA nccl collective ops. + +@@all_max +@@all_min +@@all_prod +@@all_sum +@@broadcast + +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.nccl.python.ops.nccl_ops import * -# pylint: enable=wildcard-import +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum +from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index d3af4de7211..3fc78d42531 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -66,7 +66,7 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([batch_size, input_size]) m = array_ops.zeros([batch_size, state_size]) output, state = rnn_cell.CoupledInputForgetGateLSTMCell( - num_units=num_units, forget_bias=1.0)(x, m) + num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m) sess.run([variables.global_variables_initializer()]) res = sess.run([output, state], { x.name: diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index f44302638eb..d828a337f31 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -466,12 +466,13 @@ class OutputProjectionWrapper(RNNCell): if needed or directly feed into a softmax. """ - def __init__(self, cell, output_size, reuse=None): + def __init__(self, cell, output_size, activation=None, reuse=None): """Create a cell with output projection. Args: cell: an RNNCell, a projection to output_size is added to it. output_size: integer, the size of the output after projection. + activation: (optional) an activation function to apply to the projected + output. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised.
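This hunk and the next add an `activation` hook to both projection wrappers. A short usage sketch with arbitrary sizes; `OutputProjectionWrapper` applies the activation to the projected output, `InputProjectionWrapper` to the projected input:

```python
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.ops import math_ops

base_cell = core_rnn_cell_impl.GRUCell(64)
# Project the 64-unit cell output down to 10 units, then apply tanh.
cell = core_rnn_cell_impl.OutputProjectionWrapper(
    base_cell, output_size=10, activation=math_ops.tanh)
```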
@@ -487,6 +488,7 @@ class OutputProjectionWrapper(RNNCell): self._cell = cell self._output_size = output_size self._reuse = reuse + self._activation = activation @property def state_size(self): @@ -507,6 +509,8 @@ class OutputProjectionWrapper(RNNCell): with _checked_scope(self, scope or "output_projection_wrapper", reuse=self._reuse): projected = _linear(output, self._output_size, True) + if self._activation: + projected = self._activation(projected) return projected, res_state @@ -518,12 +522,13 @@ class InputProjectionWrapper(RNNCell): do the projection on this batch-concatenated sequence, then split it. """ - def __init__(self, cell, num_proj, input_size=None): + def __init__(self, cell, num_proj, activation=None, input_size=None): """Create a cell with input projection. Args: cell: an RNNCell, a projection of inputs is added before it. num_proj: Python integer. The dimension to project to. + activation: (optional) an activation function to apply to the projected + inputs. input_size: Deprecated and unused. Raises: @@ -535,6 +540,7 @@ raise TypeError("The parameter cell is not RNNCell.") self._cell = cell self._num_proj = num_proj + self._activation = activation @property def state_size(self): @@ -553,6 +559,8 @@ # Default scope: "InputProjectionWrapper" with vs.variable_scope(scope or "input_projection_wrapper"): projected = _linear(inputs, self._num_proj, True) + if self._activation: + projected = self._activation(projected) return self._cell(projected, state) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 2cd18142131..4eb2966ef28 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -109,7 +109,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): def __init__(self, num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, - forget_bias=1.0, state_is_tuple=False, + forget_bias=1.0, state_is_tuple=True, activation=math_ops.tanh, reuse=None): """Initialize the parameters for an LSTM cell. @@ -457,7 +457,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell): start_freqindex_list=None, end_freqindex_list=None, couple_input_forget_gates=False, - state_is_tuple=False, + state_is_tuple=True, reuse=None): """Initialize the parameters for an LSTM cell. @@ -571,7 +571,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell): ValueError: if an input_size was specified and the provided inputs have a different dimension. """ - batch_size = int(inputs.get_shape()[0]) + batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] freq_inputs = self._make_tf_features(inputs) with _checked_scope(self, scope or "grid_lstm_cell", initializer=self._initializer, reuse=self._reuse): @@ -994,7 +994,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): ValueError: if an input_size was specified and the provided inputs have a different dimension. """ - batch_size = int(inputs.get_shape()[0]) + batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] fwd_inputs = self._make_tf_features(inputs) if self._backward_slice_offset: bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset) @@ -1043,7 +1043,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): """ def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None, - input_size=None, state_is_tuple=False, reuse=None): + input_size=None, state_is_tuple=True, reuse=None): """Create a cell with attention. Args: diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 9c3015ff250..606215b656f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -56,6 +56,13 @@ class AttentionWrapperTest(test.TestCase): return super(AttentionWrapperTest, self).assertAllClose( *args, **kwargs) + def testAttentionWrapperState(self): + num_fields = len(wrapper.AttentionWrapperState._fields) # pylint: disable=protected-access + state = wrapper.AttentionWrapperState(*([None] * num_fields)) + new_state = state.clone(time=1) + self.assertEqual(state.time, None) + self.assertEqual(new_state.time, 1) + def _testWithAttention(self, create_attention_mechanism, expected_final_output, diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index ac79dbecefc..d01d3751195 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -369,7 +369,26 @@ class AttentionWrapperState( - `attention_history`: (if enabled) a `TensorArray` containing attention matrices from all time steps. Call `stack()` to convert to a `Tensor`. """ - pass + + def clone(self, **kwargs): + """Clone this object, overriding components provided by kwargs. + + Example: + + ```python + initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...) + initial_state = initial_state.clone(cell_state=encoder_state) + ``` + + Args: + **kwargs: Any properties of the state object to replace in the returned + `AttentionWrapperState`. + + Returns: + A new `AttentionWrapperState` whose properties are the same as + this one, except any overridden properties as provided in `kwargs`.
+ """ + return super(AttentionWrapperState, self)._replace(**kwargs) def hardmax(logits, name=None): diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 258e74b8194..7a24261b644 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -431,7 +431,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): shape=base_shape)) all_finished = math_ops.reduce_all(finished) - no_samples = math_ops.equal(array_ops.shape(sample_ids)[0], 0) + no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids)) next_inputs = control_flow_ops.cond( math_ops.logical_or(all_finished, no_samples), lambda: base_next_inputs, maybe_sample) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 6cbe7b6d49a..01697df086d 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -31,6 +31,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook @@ -95,6 +97,22 @@ class TensorForestLossHook(session_run_hook.SessionRunHook): run_context.request_stop() +class EveryCheckpointPreSaveListener( + basic_session_run_hooks.CheckpointSaverListener): + """Runs a given op before each checkpoint save.""" + + def __init__(self, op): + """Initializes the object. + + Args: + op: An op to run before each checkpoint save. + """ + self._op = op + + def before_save(self, session, global_step_value): + session.run(self._op) + + def get_model_fn(params, graph_builder_class, device_assigner, @@ -103,6 +121,7 @@ def get_model_fn(params, num_trainers=1, trainer_id=0, report_feature_importances=False, + model_dir=None, local_eval=False): """Return a model function given a way to construct a graph builder.""" def _model_fn(features, labels, mode): @@ -138,6 +157,8 @@ def get_model_fn(params, # question of why we force everything to adhere to a single model_fn). 
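For context on how `EveryCheckpointPreSaveListener` gets wired in, a sketch mirroring what `get_model_fn` does with the `finalize_training` op in the next hunk; the directory and `finalize_op` are placeholders:

```python
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session

scaffold = monitored_session.Scaffold()
saver_hook = basic_session_run_hooks.CheckpointSaverHook(
    '/tmp/forest_model',  # placeholder model_dir
    save_secs=600,
    save_steps=None,
    scaffold=scaffold,
    # EveryCheckpointPreSaveListener is the class defined above;
    # finalize_op is a placeholder op to run before each save.
    listeners=[EveryCheckpointPreSaveListener(finalize_op)])
```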
loss_deps = [] training_graph = None + training_hooks = [] + scaffold = None if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN: training_graph = control_flow_ops.group( graph_builder.training_graph( @@ -146,6 +167,15 @@ def get_model_fn(params, trainer_id=trainer_id), state_ops.assign_add(contrib_framework.get_global_step(), 1)) loss_deps.append(training_graph) + if hasattr(graph_builder, 'finalize_training'): + finalize_listener = EveryCheckpointPreSaveListener( + graph_builder.finalize_training()) + scaffold = monitored_session.Scaffold() + training_hooks.append( + basic_session_run_hooks.CheckpointSaverHook( + model_dir, save_secs=600, save_steps=None, + scaffold=scaffold, + listeners=[finalize_listener])) training_loss = None if (mode == model_fn_lib.ModeKeys.EVAL or @@ -158,7 +188,6 @@ def get_model_fn(params, if weights is not None: features[weights_name] = weights - training_hooks = [] if early_stopping_rounds: training_hooks.append(TensorForestLossHook(early_stopping_rounds)) @@ -167,7 +196,9 @@ def get_model_fn(params, predictions=inference, loss=training_loss, train_op=training_graph, - training_hooks=training_hooks) + training_hooks=training_hooks, + scaffold=scaffold) + return _model_fn @@ -257,6 +288,7 @@ class TensorForestEstimator(estimator.Estimator): num_trainers=num_trainers, trainer_id=trainer_id, report_feature_importances=report_feature_importances, + model_dir=model_dir, local_eval=local_eval), model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 06f8c9e18f7..db6b3131383 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -43,9 +43,9 @@ py_library( srcs = ["plugins/projector/__init__.py"], srcs_version = "PY2AND3", deps = [ - ":protos_all_py", "//tensorflow/python:lib", "//tensorflow/tensorboard/plugins/projector:projector_plugin", + "//tensorflow/tensorboard/plugins/projector:protos_all_py", ], ) @@ -56,10 +56,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":projector", - ":protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", "//tensorflow/python:summary", + "//tensorflow/tensorboard/plugins/projector:protos_all_py", ], ) diff --git a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py index c11f5d065c2..635c569d734 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py +++ b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py @@ -28,11 +28,10 @@ from __future__ import print_function import os from google.protobuf import text_format -from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo -from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig from tensorflow.python.lib.io import file_io from tensorflow.tensorboard.plugins.projector import projector_plugin # pylint: disable=wildcard-import +from tensorflow.tensorboard.plugins.projector.projector_config_pb2 import * from tensorflow.tensorboard.plugins.projector.projector_plugin import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py index 96e084fa736..91ea6bc7531 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py +++ b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py @@ -24,10 
+24,10 @@ import shutil from google.protobuf import text_format from tensorflow.contrib.tensorboard.plugins import projector -from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer as writer_lib +from tensorflow.tensorboard.plugins.projector import projector_config_pb2 class ProjectorApiTest(test.TestCase): diff --git a/tensorflow/contrib/xla_tf_graph/BUILD b/tensorflow/contrib/xla_tf_graph/BUILD new file mode 100644 index 00000000000..aef277bc6e7 --- /dev/null +++ b/tensorflow/contrib/xla_tf_graph/BUILD @@ -0,0 +1,62 @@ +# Description: +# contains parts of TensorFlow that are experimental or unstable and which are not supported. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) + +cc_library( + name = "xla_tf_graph_util", + srcs = [ + "xla_tf_graph_util.cc", + ], + hdrs = [ + "xla_tf_graph_util.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "xla_tf_graph_util_test", + srcs = ["xla_tf_graph_util_test.cc"], + deps = [ + ":xla_tf_graph_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/kernels:cwise_op", + ], +) diff --git a/tensorflow/contrib/xla_tf_graph/README.md b/tensorflow/contrib/xla_tf_graph/README.md new file mode 100644 index 00000000000..a374189e813 --- /dev/null +++ b/tensorflow/contrib/xla_tf_graph/README.md @@ -0,0 +1,8 @@ +# Xla Tf Graph + +## Description + +This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools. + +Maintainers: +- Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc new file mode 100644 index 00000000000..3bad9b80675 --- /dev/null +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" + +namespace tensorflow { +namespace xla_tf_graph { + +namespace { + +constexpr const char* const GRAPH_NAME = "xla_tf_graph_util"; + +void SetupXlaCpuClient(std::unique_ptr* flib_def, + std::unique_ptr* flr, + std::unique_ptr* compiler) { + xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + // Setup compiler options + XlaCompiler::Options options; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.client = client; + compiler->reset(new XlaCompiler(options)); + + flr->reset(NewFunctionLibraryRuntime( + compiler->get()->device_mgr(), /*env=*/nullptr, compiler->get()->device(), + TF_GRAPH_DEF_VERSION, flib_def->get(), OptimizerOptions(), + /*custom_kernel_creator=*/nullptr)); +} + +} // namespace + +xla::StatusOr> +ConvertTfGraphToXlaSessionModule(const std::vector& args, + std::unique_ptr graph) { + CHECK(graph); + + std::unique_ptr flib_def; + std::unique_ptr flr; + std::unique_ptr compiler; + + SetupXlaCpuClient(&flib_def, &flr, &compiler); + + // Compile graph and build computation + XlaCompiler::CompilationResult result; + TF_CHECK_OK(compiler->CompileGraph(GRAPH_NAME, std::move(graph), flr.get(), + args, &result)); + + return result.computation.Snapshot(); +} + +} // namespace xla_tf_graph +} // namespace tensorflow diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h new file mode 100644 index 00000000000..89dca876b08 --- /dev/null +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ +#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace xla_tf_graph { + +// A set of utilities to handle xla computation requests. +// These utilities help developers leverage existing tools to work with +// xla computations, also provide a way to support TensorFlow ops by +// implementing xla computations so that they can do experiments on their +// specialized environments. 
+ +// Convert a tf graph to a xla session module +xla::StatusOr> +ConvertTfGraphToXlaSessionModule(const std::vector& args, + std::unique_ptr graph); + +} // namespace xla_tf_graph +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc new file mode 100644 index 00000000000..bab42561871 --- /dev/null +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace xla_tf_graph { + +static std::unique_ptr BuildAddGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + auto c = ops::Add(scope.WithOpName("C"), a, b); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) { + // Builds a description of the arguments. 
+ std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + std::unique_ptr graph = BuildAddGraph(); + + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr session_module, + ConvertTfGraphToXlaSessionModule(args, std::move(graph))); + + ASSERT_EQ(5, session_module->entry().requests_size()); +} + +} // namespace xla_tf_graph +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index b18209cb605..2cf668400e6 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -453,8 +453,8 @@ void BFCAllocator::RemoveFreeChunkIterFromBin( void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) { Chunk* c = ChunkFromHandle(h); CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum)); - int count = BinFromIndex(c->bin_num)->free_chunks.erase(h); - CHECK(count > 0) << "Could not find chunk in bin"; + CHECK_GT(BinFromIndex(c->bin_num)->free_chunks.erase(h), 0) + << "Could not find chunk in bin"; c->bin_num = kInvalidBinNum; } diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index 0b528cb0c27..b74c161dcec 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -78,7 +78,7 @@ class BFCAllocator : public VisitableAllocator { // A ChunkHandle is an index into the chunks_ vector in BFCAllocator // kInvalidChunkHandle means an invalid chunk - typedef int ChunkHandle; + typedef size_t ChunkHandle; static const int kInvalidChunkHandle = -1; typedef int BinNum; diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 5db49aa498c..8c4085425a1 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" @@ -304,10 +305,18 @@ Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts, tensors_to_replace.push_back({n.second, n.first.second}); } + auto graph_runner = std::unique_ptr(new GraphRunner(env)); // Evaluate the constant foldable nodes. std::vector outputs; - Status s = GraphRunner::Run(constant_graph.get(), function_library, env, - {} /* inputs*/, tensors_to_fetch_names, &outputs); + auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] { + // Output tensors need to be cleared before the GraphRunner is deleted. 
+ outputs.clear(); + graph_runner.reset(nullptr); + }); + + Status s = + graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/, + tensors_to_fetch_names, &outputs); if (!s.ok()) { VLOG(1) << "Could not fetch constants: " << s; *was_mutated = false; diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 820c4370e21..7807656cb25 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -44,7 +44,7 @@ DeviceMgr::~DeviceMgr() { } StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { - int n = s.size(); + size_t n = s.size(); char* space = name_backing_store_.Alloc(n); memcpy(space, s.data(), n); return StringPiece(space, n); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index c4b2b6c12a5..eda2be3e70f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -427,7 +427,7 @@ Status DirectSession::Run(const RunOptions& run_options, TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez)); // Start parallel Executors. - const int num_executors = executors_and_keys->items.size(); + const size_t num_executors = executors_and_keys->items.size(); ExecutorBarrier* barrier = new ExecutorBarrier( num_executors, run_state.rendez, [&run_state](const Status& ret) { { @@ -458,7 +458,7 @@ Status DirectSession::Run(const RunOptions& run_options, options_.config.graph_options().build_cost_model(); const int64 build_cost_model_after = options_.config.graph_options().build_cost_model_after(); - int measure_step_count = executor_step_count - build_cost_model_after; + int64 measure_step_count = executor_step_count - build_cost_model_after; if (measure_step_count >= 0) { update_cost_model = ((measure_step_count + 1) % build_cost_model_every == 0); @@ -611,7 +611,7 @@ Status DirectSession::PRunSetup(const std::vector& input_names, } // Start parallel Executors. - const int num_executors = executors_and_keys->items.size(); + const size_t num_executors = executors_and_keys->items.size(); ExecutorBarrier* barrier = new ExecutorBarrier( num_executors, run_state->rendez, [run_state](const Status& ret) { if (!ret.ok()) { diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index eaa54c2c48a..561e185ac4e 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -232,7 +232,7 @@ struct NodeItem { int input_start = 0; // Number of output edges. - int num_output_edges; + size_t num_output_edges; PendingCounts::Handle pending_id; @@ -307,7 +307,7 @@ class GraphView { void Initialize(const Graph* g); Status SetAllocAttrs(const Graph* g, const Device* device); - NodeItem* node(int id) const { + NodeItem* node(size_t id) const { DCHECK_GE(id, 0); DCHECK_LT(id, num_nodes_); uint32 offset = node_offsets_[id]; @@ -454,7 +454,7 @@ GraphView::~GraphView() { } size_t GraphView::NodeItemBytes(const Node* n) { - const int num_output_edges = n->out_edges().size(); + const size_t num_output_edges = n->out_edges().size(); const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); @@ -500,11 +500,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing // values as "int" vs "size_t" in CHECK_LE. 
CHECK_LE(static_cast(ptr - space_), kuint32max); - const uint32 offset = ptr - space_; + const uint32 offset = static_cast(ptr - space_); node_offsets_[id] = offset; ptr += bytes; - const int num_output_edges = n->out_edges().size(); + const size_t num_output_edges = n->out_edges().size(); const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); @@ -580,9 +580,10 @@ void GraphView::Initialize(const Graph* g) { CHECK_EQ(ptr, space_ + total_bytes); } -void GetMaxPendingCounts(const Node* n, int* max_pending, int* max_dead_count) { - const int num_in_edges = n->in_edges().size(); - int initial_count; +void GetMaxPendingCounts(const Node* n, size_t* max_pending, + size_t* max_dead_count) { + const size_t num_in_edges = n->in_edges().size(); + size_t initial_count; if (IsMerge(n)) { // merge waits all control inputs so we initialize the pending // count to be the number of control edges. @@ -626,8 +627,7 @@ Status ExecutorImpl::Initialize() { FrameInfo* frame_info = EnsureFrameInfo(frame_name); // See if this node is a root node, and if so, add to root_nodes_. - const int num_in_edges = n->in_edges().size(); - if (num_in_edges == 0) { + if (n->in_edges().empty()) { root_nodes_.push_back(n); } @@ -659,7 +659,7 @@ Status ExecutorImpl::Initialize() { // pending counts data structure, and allocate a handle in // that frame's pending counts data structure that has enough // space to store these maximal count values. - int max_pending, max_dead; + size_t max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); item->pending_id = frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead); @@ -896,7 +896,7 @@ class ExecutorState { Entry* input_tensors; // The number of outstanding ops for each iteration. - int outstanding_ops; + size_t outstanding_ops; // The number of outstanding frames for each iteration. 
int outstanding_frame_count; @@ -1037,13 +1037,13 @@ class ExecutorState { inline IterationState* GetIteration(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu) { - int index = iter % iterations.size(); + size_t index = iter % iterations.size(); return iterations[index]; } inline void SetIteration(int64 iter, IterationState* state) EXCLUSIVE_LOCKS_REQUIRED(mu) { - int index = iter % iterations.size(); + size_t index = iter % iterations.size(); DCHECK(state == nullptr || iterations[index] == nullptr); iterations[index] = state; } @@ -1404,7 +1404,7 @@ void ExecutorImpl::InitializePending(const Graph* graph, for (const Node* n : graph->nodes()) { const int id = n->id(); const string& name = cf_info.frame_names[id]; - int max_pending, max_dead; + size_t max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); const NodeItem* item = gview_.node(id); PendingCounts* counts = EnsureFrameInfo(name)->pending_counts; @@ -2027,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node, } bool completed = false; - int ready_size = ready.size(); + size_t ready_size = ready.size(); if (ready_size == 0 || !s.ok()) { completed = (num_outstanding_ops_.fetch_sub(1) == 1); } else if (ready_size > 1) { @@ -2375,10 +2375,10 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, TaggedNodeSeq* ready) { const GraphView& gview = executor->gview_; IterationState* iter_state = GetIteration(iter); - const int num_output_edges = item->num_output_edges; + const size_t num_output_edges = item->num_output_edges; const EdgeInfo* edges = item->output_edge_list(); Entry* input_tensors = iter_state->input_tensors; - for (int out_index = 0; out_index < num_output_edges; out_index++) { + for (size_t out_index = 0; out_index < num_output_edges; out_index++) { const EdgeInfo& e = edges[out_index]; const int dst_id = e.dst_id; const NodeItem* dst_item = gview.node(dst_id); diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 239c9666e33..93b58906dda 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -162,7 +162,7 @@ class ExecutorBarrier { // // 'done' is called after the last executor completes, and // ExecutorBarrier is deleted. - ExecutorBarrier(int num, Rendezvous* r, StatusCallback done) + ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done) : rendez_(r), done_cb_(done), pending_(num) {} ~ExecutorBarrier() {} diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index cb7e1a40ceb..5f011c2ce94 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -274,8 +274,9 @@ class CallOp : public AsyncOpKernel { if (!status.ok()) { ctx->SetStatus(status); } else { - CHECK_EQ(rets->size(), ctx->num_outputs()); - for (size_t i = 0; i < rets->size(); ++i) { + const int ret_size = static_cast(rets->size()); + CHECK_EQ(ret_size, ctx->num_outputs()); + for (int i = 0; i < ret_size; ++i) { ctx->set_output(i, (*rets)[i]); } } @@ -1000,7 +1001,7 @@ string NewName(const Node* n, bool pretty) { void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { // We visit nodes in forward topological sort order, which is a // possible execution order of the graph. 
- std::vector pending(g->num_node_ids()); + std::vector pending(g->num_node_ids()); std::deque ready; for (const Node* n : g->nodes()) { pending[n->id()] = n->in_edges().size(); @@ -1154,7 +1155,7 @@ FunctionBody* SymbolicGradientHelper::Compute() { Graph* g = gbody_->graph; - const int num_y = gbody_->ret_nodes.size(); + const int num_y = static_cast(gbody_->ret_nodes.size()); // Populate 'y_node_outputs_' with node function body outputs. // Populate 'y_grad_nodes' with initial gradient nodes for each return node of @@ -1169,7 +1170,7 @@ FunctionBody* SymbolicGradientHelper::Compute() { y_node_outputs.push_back({y, 0}); DCHECK_EQ(y->type_string(), kRetOp); const DataType dtype = y->input_type(0); - const int index = gbody_->arg_nodes.size(); + const int index = static_cast(gbody_->arg_nodes.size()); Node* dy = AddArg(g, dtype, index); gbody_->arg_types.push_back(dtype); gbody_->arg_nodes.push_back(dy); @@ -1177,7 +1178,7 @@ FunctionBody* SymbolicGradientHelper::Compute() { } // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs'). - const int num_x = fbody_->arg_nodes.size(); + const size_t num_x = fbody_->arg_nodes.size(); std::vector x_node_outputs; x_node_outputs.reserve(num_x); for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { @@ -1200,7 +1201,8 @@ FunctionBody* SymbolicGradientHelper::Compute() { gbody_->ret_nodes.clear(); // Add new return nodes to the function gradient body for each node // in 'x_grad_nodes'. - for (size_t i = 0; i < fbody_->arg_types.size(); ++i) { + const int arg_types_size = static_cast(fbody_->arg_types.size()); + for (int i = 0; i < arg_types_size; ++i) { Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index}; Node* ret = AddRet(g, grad, i); gbody_->ret_nodes.push_back(ret); diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc index eae917a4395..de715d140a1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc @@ -82,7 +82,7 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts, // Determine a suitable stream to use. int stream_id = highest_stream_id + 1; for (const Edge* e : n->in_edges()) { - const int fanout = e->src()->out_edges().size(); + const size_t fanout = e->src()->out_edges().size(); if (fanout == 1) { stream_id = (*node_to_stream_id)[e->src()->id()]; break; diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index f9975ef0a08..cee7b6d78ad 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -191,7 +191,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) { // example, process_state could maybe save the first stream executor // it knows is valid. gpu::StreamExecutor* se = nullptr; - for (size_t i = 0; i < gpu_allocators_.size(); ++i) { + for (int i = 0; i < static_cast(gpu_allocators_.size()); ++i) { if (gpu_allocators_[i] != nullptr) { se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie(); break; diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index c93ff1cdde8..d4dc8f0057e 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/graph_runner.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/memory_types.h" @@ -95,22 +94,24 @@ class SimpleRendezvous : public Rendezvous { } // namespace -// static +GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {} + +GraphRunner::~GraphRunner() {} + Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, - Env* env, const NamedTensorList& inputs, + const NamedTensorList& inputs, const std::vector<string>& output_names, std::vector<Tensor>* outputs) { + if (cpu_device_ == nullptr) { + return errors::NotFound("Cannot find a device for GraphRunner."); + } + // TODO(vrv): Instead of copying the entire graph, consider modifying // the existing graph, and then removing those removed edges. // prior to returning. std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry())); CopyGraph(*graph, graph_to_run.get()); - std::unique_ptr<Device> device = GetCPUDevice(env); - if (!device) { - return errors::NotFound("Cannot find a device for GraphRunner."); - } - SimpleRendezvous* rendez = new SimpleRendezvous; core::ScopedUnref rendez_unref(rendez); @@ -130,7 +131,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // Call RewriteGraphForExecution TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( graph_to_run.get(), input_names, output_names, {} /* target nodes */, - device->attributes())); + cpu_device_->attributes())); // Create the local executor and the Rendezvous for fetching back the // constants. @@ -143,10 +144,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, Graph* g = graph_to_run.release(); LocalExecutorParams params; - params.device = device.get(); + // The ownership of the output tensors is bound to this device's lifetime. + params.device = cpu_device_.get(); params.function_library = function_library; - params.create_kernel = [&device, g](const NodeDef& ndef, OpKernel** kernel) { - return CreateNonCachedKernel(device.get(), nullptr, ndef, + params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) { + return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, g->versions().producer(), kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; diff --git a/tensorflow/core/common_runtime/graph_runner.h b/tensorflow/core/common_runtime/graph_runner.h index e078c7ffc8c..24e8b04c463 100644 --- a/tensorflow/core/common_runtime/graph_runner.h +++ b/tensorflow/core/common_runtime/graph_runner.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" @@ -44,16 +45,26 @@ namespace tensorflow { // to be particularly lightweight, fast, or efficient. class GraphRunner { public: + // REQUIRES: `env` is not nullptr. + GraphRunner(Env* env); + ~GraphRunner(); + // Function semantics for `inputs`, `output_names` and `outputs` // matches those from Session::Run(). + // + // NOTE: The output tensors share lifetime with the GraphRunner, and could + // be destroyed once the GraphRunner is destroyed. + // - // REQUIRES: `graph`, `env`, and `outputs` are not nullptr. + // REQUIRES: `graph` and `outputs` are not nullptr. // `function_library` may be nullptr.
typedef std::vector<std::pair<string, Tensor>> NamedTensorList; - static Status Run(Graph* graph, FunctionLibraryRuntime* function_library, - Env* env, const NamedTensorList& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs); + Status Run(Graph* graph, FunctionLibraryRuntime* function_library, + const NamedTensorList& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs); + + private: + std::unique_ptr<Device> cpu_device_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_runner_test.cc b/tensorflow/core/common_runtime/graph_runner_test.cc index 5918ba9a22d..ccb44af0ec2 100644 --- a/tensorflow/core/common_runtime/graph_runner_test.cc +++ b/tensorflow/core/common_runtime/graph_runner_test.cc @@ -46,9 +46,9 @@ using test::internal::ExpectEqual; TEST(GraphRunnerTest, SingleConst) { Scope root = Scope::NewRootScope(); auto c = ops::Const(root, 42.0f); + GraphRunner graph_runner(Env::Default()); std::vector<Tensor> outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {}, - {c.name()}, &outputs); + Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs); TF_ASSERT_OK(s); ExpectEqual(42.0f, outputs[0].scalar<float>()()); } @@ -57,9 +57,10 @@ TEST(GraphRunnerTest, MultiFetchConst) { Scope root = Scope::NewRootScope(); auto c = ops::Const(root, 42.0f); auto pi = ops::Const(root, 3.14f); + GraphRunner graph_runner(Env::Default()); std::vector<Tensor> outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {}, - {c.name(), pi.name()}, &outputs); + Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()}, + &outputs); TF_ASSERT_OK(s); ExpectEqual(42.0f, outputs[0].scalar<float>()()); ExpectEqual(3.14f, outputs[1].scalar<float>()()); @@ -78,9 +79,10 @@ TEST(GraphRunnerTest, FeedAndFetch) { std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data}, {"p2:0", p2_data}}; + GraphRunner graph_runner(Env::Default()); std::vector<Tensor> outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), inputs, - {"add:0"}, &outputs); + Status s = + graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs); TF_ASSERT_OK(s); ExpectEqual(3.0f, outputs[0].scalar<float>()()); } diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h index f0c79ad601c..198eb896afc 100644 --- a/tensorflow/core/common_runtime/pending_counts.h +++ b/tensorflow/core/common_runtime/pending_counts.h @@ -69,7 +69,7 @@ class PendingCounts { // to retrieve the count data for this node.
class Layout { public: - Handle CreateHandle(int max_pending_count, int max_dead_count); + Handle CreateHandle(size_t max_pending_count, size_t max_dead_count); private: friend class PendingCounts; @@ -91,7 +91,7 @@ class PendingCounts { ~PendingCounts() { delete[] bytes_; } - void set_initial_count(Handle h, int pending_count) { + void set_initial_count(Handle h, size_t pending_count) { if (h.is_large_) { LargeCounts* c = Large(h); c->pending = pending_count; @@ -306,7 +306,7 @@ class PendingCounts { }; inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( - int max_pending_count, int max_dead_count) { + size_t max_pending_count, size_t max_dead_count) { Handle result; if ((max_pending_count > kMaxCountForPackedCounts) || (max_dead_count > kMaxCountForPackedCounts)) { diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 2f65abde0af..f58faefa9fb 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -19,7 +19,6 @@ limitations under the License. #include <unordered_set> #include <vector> -#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -33,7 +32,15 @@ using shape_inference::ShapeHandle; ShapeRefiner::ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops) - : graph_def_version_(graph_def_version), ops_registry_(ops) {} + : graph_def_version_(graph_def_version), + ops_registry_(ops), + graph_runner_(Env::Default()) {} + +ShapeRefiner::~ShapeRefiner() { + // The lifetime of the tensors is bound to the GraphRunner, so the tensors + // should be deleted before it. + const_tensor_map_.clear(); +} Status ShapeRefiner::AddNode(const Node* node) { // For each 'input' of this node, fetch the corresponding shape @@ -223,9 +230,8 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, std::vector<Tensor> outputs; // NOTE; we should pass in a function library runtime if we want // to support constant-expression evaluation on functions. - Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */, - Env::Default(), const_inputs, - {output_tensor_name}, &outputs); + Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */, + const_inputs, {output_tensor_name}, &outputs); // If all kernels in the constant graph are not registered // in the process, GraphRunner::Run may fail, in which case diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index b8d69fc05b8..bbde0924c7f 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -17,6 +17,7 @@ limitations under the License. #include <vector> +#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -32,6 +33,7 @@ namespace tensorflow { class ShapeRefiner { public: ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); + ~ShapeRefiner(); // Performs validation of 'node' and runs 'node's shape function, // storing its shape outputs. @@ -101,6 +103,10 @@ class ShapeRefiner { const int graph_def_version_; const OpRegistryInterface* const ops_registry_; + // The lifetime of the tensors is bound to the runner, so it should be + // deleted after the tensors.
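+  // (C++ destroys members in reverse declaration order; since graph_runner_
+  // is declared before const_tensor_map_, the cached tensors would be
+  // destroyed first even without the explicit clear() in ~ShapeRefiner,
+  // which simply makes the ordering requirement explicit.)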
+ GraphRunner graph_runner_; + // Stores a map from a node to its InferenceContext. // // Owns values. @@ -118,6 +124,7 @@ class ShapeRefiner { // Only tensors less than 1KiB are currently stored in the cache. static constexpr int64 kMaxTensorSize = 1024; std::unordered_map<string, Tensor> const_tensor_map_; + TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); }; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 171e75c0d0a..537d489aae9 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -375,9 +375,7 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, } } call_state->mu.lock(); - if (status.ok()) { - call_state->shared_status = status; - } + call_state->shared_status.Update(status); call_state->done_counter--; // If we are the last async call to return, call the done callback. if (call_state->done_counter == 0) { diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index facbfbb2643..583776b7e33 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -198,7 +198,7 @@ class DeviceFinder { while (num_pending_ != 0) { pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs)); if (num_pending_ != 0) { - for (int i = 0; i < targets_.size(); ++i) { + for (int i = 0; static_cast<size_t>(i) < targets_.size(); ++i) { if (!seen_targets_[i]) { LOG(INFO) << "CreateSession still waiting for response from worker: " diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 870df353cb6..bd848849b5d 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -440,19 +440,6 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( return s; } -static bool CopyIfNeeded(TensorProto* in, TensorProto* out) { - if (in->tensor_content().empty()) { - // If the tensor is not encoded in tensor_content or contains 0 - // elements, we can return it to the client directly. - out->Swap(in); - } else { - Tensor t(in->dtype()); - if (!t.FromProto(cpu_allocator(), *in)) return false; - t.AsProtoTensorContent(out); - } - return true; -} - // Helper class to manage "num" parallel RunGraph calls. class RunManyGraphs { public: @@ -565,7 +552,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( // We keep these as separate paths for now, to ensure we aren't // inadvertently slowing down the normal run path. if (is_partial_) { - for (int i = 0; i < req.num_feeds(); ++i) { + for (int i = 0; static_cast<size_t>(i) < req.num_feeds(); ++i) { const string& name = req.feed_name(i); auto iter = part.feed_key.find(name); if (iter == part.feed_key.end()) { @@ -577,7 +564,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( if (feeds_iter == feeds.end()) { return errors::InvalidArgument("No feed is provided for feed=", name, ", key=", key); - } else if (feeds_iter->second != i) { + } else if (feeds_iter->second != static_cast<size_t>(i)) { return errors::Internal("Cannot find feed named \"", name, " in request."); } @@ -585,7 +572,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( } // TODO(suharshs): Make a map from feed to fetch_key to make this faster. // For now, we just iterate through partitions to find the matching key.
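Both the GraphMgr change above and the matching Worker change below collapse a hand-rolled "first error wins" conditional into Status::Update, which stores the incoming status only when no error has been recorded yet and is a no-op otherwise. A minimal sketch of those semantics (error strings illustrative):

  Status shared;                              // starts OK
  shared.Update(Status::OK());                // still OK
  shared.Update(errors::Internal("boom"));    // records the first error
  shared.Update(errors::Cancelled("late"));   // ignored: first error wins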
- for (int i = 0; i < req.num_fetches(); ++i) { + for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) { const string& req_fetch = req.fetch_name(i); for (const auto& key_fetch : part.key_fetch) { if (key_fetch.second == req_fetch) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 7854949033f..80a2f89337c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -49,6 +49,9 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { case GrpcWorkerMethod::kTracing: return "/tensorflow.WorkerService/Tracing"; } + // Shouldn't be reached. + LOG(FATAL) << "Invalid id: this line shouldn't be reached."; + return "invalid id"; } namespace grpc { diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 16be15fe662..1aced4443f8 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -117,9 +117,7 @@ void Worker::MaybeCallFinalCallback(const string& graph_handle, int step_id, } } if (done != nullptr) { - if (s.ok()) { - s = executor_status; - } + s.Update(executor_status); done(s); } } diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index 63294c695e5..b241ab6ab32 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -247,7 +247,7 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, // Add a restore subgraph for each variable and connect to the restore_all op. // For each variable we add the following subgraph: // Assign----restore_all -// / \ +// | | // RestoreV2 Variable Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, const std::vector<const Edge*>& in_edges, @@ -311,7 +311,7 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, // Adds new variables to save and restore ops matching the Save and Restore // graphs created in tensorflow/python/training/saver.py. Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) { - Node* save_op; + Node* save_op = nullptr; std::vector<const Edge*> in_edges; bool found = false; TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found)); diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 19266208fe1..9eab8495ce6 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -101,6 +101,11 @@ OpInfo::DeviceProperties GetLocalCPUInfo() { OpInfo::DeviceProperties device; device.set_type("CPU"); + device.set_vendor(port::CPUVendorIDString()); + // Combine cpu family and model into the model string.
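+  // For example, a CPU reporting family 6 and model 15 would be encoded as
+  // (6 << 4) + 15 = 111, so the model string becomes "111". (The values are
+  // illustrative; see the cpu_info.cc change below for how the fields are
+  // read from CPUID.)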
+ device.set_model( + strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); + device.set_frequency(port::NominalCPUFrequency()); device.set_num_cores(port::NumSchedulableCPUs()); device.set_l1_cache_size(Eigen::l1CacheSize()); device.set_l2_cache_size(Eigen::l2CacheSize()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 49891e2a780..7cddedef2e2 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -130,7 +130,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (!status.ok()) { return false; } - if (op_def->is_stateful()) { return false; } @@ -144,6 +143,15 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + // No need to (and don't) fold nodes that have no outgoing edges. Such nodes + // could be introduced by an earlier constant folding pass and are preserved + // in case users want to fetch their values; re-processing them would + // lead to an error of adding a duplicated node to graph. + auto outputs = node_map_->GetOutputs(node.name()); + if (outputs.empty()) { + return false; + } + for (const auto& input : node.input()) { bool is_const = IsConst(*node_map_->GetNode(input)); if (!is_const) { @@ -224,8 +232,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, Status(error::INVALID_ARGUMENT, "Expected at least one output."); } for (int i = 0; i < output_tensors.size(); i++) { - string node_name = strings::StrCat( - AddPrefixToNodeName(node.name(), kConstantFoldingConst)); + string node_name = AddPrefixToNodeName(node.name(), kConstantFoldingConst); if (output_tensors.size() > 1) { node_name = strings::StrCat(node_name, "-", i); } @@ -299,6 +306,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, nodes_to_preserve_.insert(NodeName(node)); } device_.reset(new DeviceSimple()); + *output = GraphDef(); TF_RETURN_IF_ERROR(FoldGraph(output)); LOG(INFO) << "Optimized graph size: " << output->node_size(); return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index fbb7e849ba2..d1ab5a1d9b4 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -64,5 +64,15 @@ bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { control_dependency_drivers_.end(); } +bool GraphRewriter::IsDrivenByControlDependency(const NodeDef& node) const { + for (const auto& input : node.input()) { + CHECK(!input.empty()); + if (input[0] == '^') { + return true; + } + } + return false; +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index a9cc777809f..adbe5a24c86 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -43,6 +43,10 @@ class GraphRewriter { // a control dependency edge. bool DrivesControlDependency(const NodeDef& node) const; + // Returns true if at least one of the incident edges is a control dependency + // edge. 
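+  // (Control inputs are encoded in NodeDef.input() with a '^' prefix; e.g.
+  // "^init" denotes a control edge from a node named "init", so the
+  // implementation only needs to look at the first character of each input.
+  // The node name is illustrative.)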
+ bool IsDrivenByControlDependency(const NodeDef& node) const; + private: std::unordered_map<string, const NodeDef*> nodes_; std::unordered_set<const NodeDef*> control_dependency_drivers_; diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index b0988b8a891..5438bff0c1e 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -508,7 +508,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { AttrValue attr_tensor; Tensor tensor(DT_INT32, TensorShape({4})); std::vector<int> shape = {1, num_channels, 1, 1}; - for (int i = 0; i < shape.size(); i++) { + for (int i = 0; static_cast<size_t>(i) < shape.size(); i++) { tensor.flat<int>()(i) = shape[i]; } tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); @@ -615,11 +615,9 @@ class ReluGradProcessor : public AgnosticNodeProcessor { } }; -// This is the older, less optimized gather-based SliceProcessor. We keep it as -// a test case for constant propagation optimization. -class SliceProcessorGatherBased : public AgnosticNodeProcessor { +class SliceProcessor : public AgnosticNodeProcessor { public: - SliceProcessorGatherBased(GraphDef* graph, NodeDef* node, NodeMap* node_map) + SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) : AgnosticNodeProcessor(graph, node, node_map) {} protected: @@ -663,9 +661,30 @@ } }; -class SliceProcessor : public AgnosticNodeProcessor { +// Specialized SliceProcessor, used if the second and third input are const +// nodes, which could be the case if a constant folding pass is applied +// before this optimization. +class SliceProcessorConst : public AgnosticNodeProcessor { public: - SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) + SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map) + : AgnosticNodeProcessor(graph, node, node_map) {} + + protected: + Status CustomizedProcessing() override { + // Skip the first input, which is the data to be sliced. + for (int i = 1; i < node_->input_size(); i++) { + auto shape_node = node_map_->GetNode(node_->input(i)); + TF_RETURN_IF_ERROR(UpdateAttrValue(shape_node)); + } + return Status::OK(); + } +}; + +// Specialized SliceProcessor, used if the second input is ConcatOffset. An +// example use case is in the gradient computation of Concat for InceptionV3.
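+// (Dispatch sketch: DataLayoutOptimizer below inspects the ops of inputs 1
+// and 2 of each Slice node. ConcatOffset selects this processor, two Const
+// inputs select SliceProcessorConst, and anything else falls back to the
+// generic SliceProcessor.)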
+class SliceProcessorConcatOffset : public AgnosticNodeProcessor { + public: + SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map) : AgnosticNodeProcessor(graph, node, node_map) {} protected: @@ -832,7 +851,7 @@ class DataLayoutOptimizer { node->mutable_attr()->insert({"dtype", attr_data_type}); AttrValue attr_tensor; Tensor tensor(DT_INT32, TensorShape({4})); - for (int i = 0; i < permutation.size(); i++) { + for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) { tensor.flat<int>()(i) = permutation[i]; } tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); @@ -866,7 +885,7 @@ class DataLayoutOptimizer { AttrValue attr_tensor; Tensor tensor(DT_INT32, TensorShape({3})); std::vector<int> axis = {0, 2, 3}; - for (int i = 0; i < axis.size(); i++) { + for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) { tensor.flat<int>()(i) = axis[i]; } tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); @@ -938,14 +957,17 @@ class DataLayoutOptimizer { node_processor.reset( new ReluGradProcessor(graph_, node, &node_map_)); } else if (node->op().compare("Slice") == 0) { - auto maybe_concatoffset_node = - node_map_.GetNode(NodeName(node->input(1))); - if (maybe_concatoffset_node->op() == "ConcatOffset") { + auto input1 = node_map_.GetNode(NodeName(node->input(1))); + auto input2 = node_map_.GetNode(NodeName(node->input(2))); + if (input1->op() == "ConcatOffset") { node_processor.reset( - new SliceProcessor(graph_, node, &node_map_)); + new SliceProcessorConcatOffset(graph_, node, &node_map_)); + } else if (input1->op() == "Const" && input2->op() == "Const") { + node_processor.reset( + new SliceProcessorConst(graph_, node, &node_map_)); } else { node_processor.reset( - new SliceProcessorGatherBased(graph_, node, &node_map_)); + new SliceProcessor(graph_, node, &node_map_)); } } else if (node->op().compare("Squeeze") == 0) { diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index a89831b6e64..47072665728 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -51,24 +51,27 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } // Don't remove nodes that drive control dependencies. - if (!rewriter.DrivesControlDependency(node)) { + // Don't remove nodes that are driven by control dependencies either since + // we can't ensure (yet) that we won't increase the number of control + // dependency edges by deleting them (for example, removing a node driven by + // 10 control edges and driving 10 control edges would result in the + // creation of 100 edges). + if (!rewriter.DrivesControlDependency(node) && + !rewriter.IsDrivenByControlDependency(node)) { nodes_to_delete.insert(&node); } } for (auto& node : item.graph.node()) { - if (nodes_to_delete.find(&node) != nodes_to_delete.end()) { - continue; - } NodeDef* new_node = pruned_graph->add_node(); *new_node = node; new_node->clear_input(); rewriter.ForwardInputs(node, nodes_to_delete, new_node); } - LOG(INFO) << "Pruned " << nodes_to_delete.size() - << " nodes from the graph. The graph now contains " - << pruned_graph->node_size() << " nodes."; + VLOG(1) << "Pruned " << nodes_to_delete.size() + << " nodes from the graph. 
The graph now contains " + << pruned_graph->node_size() << " nodes."; return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 47d45a6f490..67954d29146 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -70,16 +70,22 @@ TEST_F(ModelPrunerTest, StopGradientPruning) { Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(3, output.node_size()); + EXPECT_EQ(5, output.node_size()); const NodeDef& new_a = output.node(0); EXPECT_EQ(NodeName(a.name()), new_a.name()); const NodeDef& new_b = output.node(1); EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_e = output.node(2); + const NodeDef& new_c = output.node(2); + EXPECT_EQ(NodeName(c.name()), new_c.name()); + const NodeDef& new_d = output.node(3); + EXPECT_EQ(NodeName(d.name()), new_d.name()); + const NodeDef& new_e = output.node(4); EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(1, new_e.input_size()); EXPECT_EQ(NodeName(b.name()), new_e.input(0)); + EXPECT_EQ(1, new_d.input_size()); + EXPECT_EQ(NodeName(b.name()), new_d.input(0)); } TEST_F(ModelPrunerTest, IdentityPruning) { @@ -104,18 +110,22 @@ TEST_F(ModelPrunerTest, IdentityPruning) { Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(4, output.node_size()); + EXPECT_EQ(5, output.node_size()); const NodeDef& new_a = output.node(0); EXPECT_EQ(NodeName(a.name()), new_a.name()); const NodeDef& new_b = output.node(1); EXPECT_EQ(NodeName(b.name()), new_b.name()); const NodeDef& new_c = output.node(2); EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_e = output.node(3); + const NodeDef& new_d = output.node(3); + EXPECT_EQ(NodeName(d.name()), new_d.name()); + const NodeDef& new_e = output.node(4); EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(1, new_e.input_size()); EXPECT_EQ(NodeName(c.name()), new_e.input(0)); + EXPECT_EQ(1, new_d.input_size()); + EXPECT_EQ(NodeName(c.name()), new_d.input(0)); } TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { @@ -142,14 +152,16 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(4, output.node_size()); + EXPECT_EQ(5, output.node_size()); const NodeDef& new_a = output.node(0); EXPECT_EQ(NodeName(a.name()), new_a.name()); const NodeDef& new_b = output.node(1); EXPECT_EQ(NodeName(b.name()), new_b.name()); const NodeDef& new_c = output.node(2); EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_e = output.node(3); + const NodeDef& new_d = output.node(3); + EXPECT_EQ(NodeName(d.name()), new_d.name()); + const NodeDef& new_e = output.node(4); EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(2, new_e.input_size()); @@ -157,7 +169,7 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { EXPECT_EQ("^c", new_e.input(1)); } -TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) { +TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { // Build a simple graph with a few trivially prunable ops. 
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -183,20 +195,28 @@ Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(4, output.node_size()); + EXPECT_EQ(6, output.node_size()); const NodeDef& new_a = output.node(0); EXPECT_EQ(NodeName(a.name()), new_a.name()); const NodeDef& new_b = output.node(1); EXPECT_EQ(NodeName(b.name()), new_b.name()); const NodeDef& new_c = output.node(2); EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_f = output.node(3); + const NodeDef& new_d = output.node(3); + EXPECT_EQ(NodeName(d.name()), new_d.name()); + const NodeDef& new_e = output.node(4); + EXPECT_EQ(NodeName(e.name()), new_e.name()); + const NodeDef& new_f = output.node(5); EXPECT_EQ(NodeName(f.name()), new_f.name()); - EXPECT_EQ(3, new_f.input_size()); - EXPECT_EQ(NodeName(c.name()), new_f.input(0)); - EXPECT_EQ("^b", new_f.input(1)); - EXPECT_EQ("^c", new_f.input(2)); + EXPECT_EQ(1, new_f.input_size()); + EXPECT_EQ(NodeName(e.name()), new_f.input(0)); + EXPECT_EQ(2, new_e.input_size()); + EXPECT_EQ(NodeName(d.name()), new_e.input(0)); + EXPECT_EQ("^c", new_e.input(1)); + EXPECT_EQ(2, new_d.input_size()); + EXPECT_EQ(NodeName(c.name()), new_d.input(0)); + EXPECT_EQ("^b", new_d.input(1)); } TEST_F(ModelPrunerTest, PruningPerservesFetch) { diff --git a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc index f2b1958105b..ee548c6887e 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc @@ -47,7 +47,7 @@ Status RewriteQuantizedStrippedModelForHexagon( "graph execute op..."; std::vector<std::pair<string, Tensor>> inputs; std::vector<Tensor> outputs; - for (auto i = 0; i < context.input_names.size(); ++i) { + for (auto i = 0; static_cast<size_t>(i) < context.input_names.size(); ++i) { const string& input_name = context.input_names.at(i); // Get input shape diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h index 33383d16a86..9ef44a57827 100644 --- a/tensorflow/core/kernels/image_resizer_state.h +++ b/tensorflow/core/kernels/image_resizer_state.h @@ -122,7 +122,7 @@ struct ImageResizerState { int64 channels; float height_scale; float width_scale; - Tensor* output; + Tensor* output = nullptr; private: bool align_corners_; diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc index f7c1e6c52c1..7e57331ab4f 100644 --- a/tensorflow/core/kernels/range_sampler.cc +++ b/tensorflow/core/kernels/range_sampler.cc @@ -262,7 +262,7 @@ FixedUnigramSampler::FixedUnigramSampler(int64 range, } float FixedUnigramSampler::Probability(int64 value) const { - if (value >= weights_.size() || value < 0) { + if (value < 0 || static_cast<size_t>(value) >= weights_.size()) { return 0.0; } return weights_.at(value) / total_weight_; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index cb16ffe6755..a6eb06004a8 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -168,12 +168,12 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { output_tensors.push_back(input_node_info.second); } - for (int i = 0; i < output_node_names.size(); ++i) { + for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) { const string& 
name = output_node_names.at(i); const Tensor& tensor = output_tensors.at(i); EmplaceTensorShapeType(name, tensor, tensor_shape_map); } - for (int i = 0; i < input_node_info_list.size(); ++i) { + for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) { const string& name = input_node_info_list.at(i).first; const Tensor& tensor = output_tensors.at(output_node_names.size() + i); EmplaceTensorShapeType(name, tensor, tensor_shape_map); diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index d7b063e0c18..5131bce448e 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -235,9 +235,9 @@ inline void interpolate_with_caching( const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width; if (num_channels == 3) { // Manually unroll case of 3 channels. - float cached_value_0[4]; - float cached_value_1[4]; - float cached_value_2[4]; + float cached_value_0[4] = {0}; + float cached_value_1[4] = {0}; + float cached_value_2[4] = {0}; for (int64 x = 0; x < resizer_state.out_width; ++x) { const WeightsAndIndices& x_wai = x_wais[x]; // Shift values in cached_value_* to fill first 'advance' values. @@ -316,7 +316,7 @@ inline void interpolate_with_caching( } } else { for (int64 c = 0; c < num_channels; ++c) { - float cached_value[4]; + float cached_value[4] = {0}; for (int64 x = 0; x < resizer_state.out_width; ++x) { const WeightsAndIndices& x_wai = x_wais[x]; // Shift values in cached_value to fill first 'advance' values. diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 8202719b4dd..5fc9ce5f928 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -288,8 +288,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { // may no longer be resized by new Writes. tensor_array->DisableDynamicSize(); - int32 array_size; - int32 marked_size; + int32 array_size = 0; + int32 marked_size = 0; TF_RETURN_IF_ERROR(tensor_array->Size(&array_size)); TF_RETURN_IF_ERROR(tensor_array->MarkedSize(&marked_size)); @@ -615,6 +615,12 @@ class TensorArrayPackOrGatherOp : public OpKernel { Tensor* output_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); + + // If output_tensor is empty, there is nothing to concatenate so return it. + if (output_shape.num_elements() == 0) { + return; + } + ConstMatrixVector input_tensors_flat; input_tensors_flat.reserve(num_indices); auto output_flat = diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index fb2ceb4a4a8..bab8e1ee121 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -89,7 +89,7 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation") .HostMemory("x") .HostMemory("y"), InvertPermutationOp); -#endif // TENSORFLOW_USE_SYCL +#endif  // TENSORFLOW_USE_SYCL // output = TransposeOp(T input, T perm) takes a tensor // of type T and rank N, and a permutation of 0, 1, ..., N-1. 
It @@ -115,11 +115,6 @@ void TransposeOp::Compute(OpKernelContext* ctx) { perm.shape().DebugString())); auto Vperm = perm.vec<int32>(); const int dims = input.dims(); - static const int kMinDims = 0; - static const int kMaxDims = 10; - OP_REQUIRES(ctx, kMinDims <= dims && dims <= kMaxDims, - errors::Unimplemented("Transposing a tensor of rank ", dims, - " is not implemented.")); OP_REQUIRES(ctx, dims == Vperm.size(), errors::InvalidArgument( "transpose expects a vector of size ", input.dims(), diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc index 5a87da6c86a..b2921c076cc 100644 --- a/tensorflow/core/lib/io/block_builder.cc +++ b/tensorflow/core/lib/io/block_builder.cc @@ -70,10 +70,12 @@ size_t BlockBuilder::CurrentSizeEstimate() const { StringPiece BlockBuilder::Finish() { // Append restart array - for (size_t i = 0; i < restarts_.size(); i++) { - core::PutFixed32(&buffer_, restarts_[i]); + CHECK_LE(restarts_.size(), std::numeric_limits<uint32>::max()); + for (const auto r : restarts_) { + core::PutFixed32(&buffer_, r); } - core::PutFixed32(&buffer_, restarts_.size()); + // Downcast safe because of the CHECK. + core::PutFixed32(&buffer_, static_cast<uint32>(restarts_.size())); finished_ = true; return StringPiece(buffer_); } @@ -93,19 +95,24 @@ void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) { } } else { // Restart compression - restarts_.push_back(buffer_.size()); + CHECK_LE(buffer_.size(), std::numeric_limits<uint32>::max()); + restarts_.push_back(static_cast<uint32>(buffer_.size())); counter_ = 0; } const size_t non_shared = key.size() - shared; + CHECK_LE(shared, std::numeric_limits<uint32>::max()); + CHECK_LE(non_shared, std::numeric_limits<uint32>::max()); + CHECK_LE(value.size(), std::numeric_limits<uint32>::max()); + // Add "<shared><non_shared><value_size>" to buffer_ - core::PutVarint32(&buffer_, shared); - core::PutVarint32(&buffer_, non_shared); - core::PutVarint32(&buffer_, value.size()); + core::PutVarint32(&buffer_, static_cast<uint32>(shared)); + core::PutVarint32(&buffer_, static_cast<uint32>(non_shared)); + core::PutVarint32(&buffer_, static_cast<uint32>(value.size())); // Add string delta to buffer_ followed by value buffer_.append(key.data() + shared, non_shared); - buffer_.append(value.data(), value.size()); + buffer_.append(value.data(), static_cast<uint32>(value.size())); // Update state last_key_.resize(shared); diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc index ab2fd7739f7..d93dd0296e4 100644 --- a/tensorflow/core/lib/io/path.cc +++ b/tensorflow/core/lib/io/path.cc @@ -177,7 +177,7 @@ string CleanPath(StringPiece unclean_path) { } // Calculate and check the length of the cleaned path.
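The BlockBuilder records written above use leveldb-style prefix compression: each key is stored as a delta against the previous key, preceded by a <shared><non_shared><value_size> varint header. A worked example (keys and value illustrative):

  // last_key_ = "apple", key = "applesauce", value = "v1"
  // shared     = 5   ("apple" is the common prefix)
  // non_shared = 5   ("sauce")
  // value_size = 2
  // buffer_ receives varint32(5) varint32(5) varint32(2) "sauce" "v1"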
- int path_length = dst - path.begin(); + string::difference_type path_length = dst - path.begin(); if (path_length != 0) { // Remove trailing '/' except if it is root path ("/" ==> path_length := 1) if (path_length > 1 && path[path_length - 1] == '/') { diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 97e218a7931..028ff26ffb9 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -262,7 +262,7 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string, const uint32 data_count = *sample_count * *channel_count; float_values->resize(data_count); for (int i = 0; i < data_count; ++i) { - int16 single_channel_value; + int16 single_channel_value = 0; TF_RETURN_IF_ERROR( ReadValue(wav_string, &single_channel_value, &offset)); (*float_values)[i] = Int16SampleToFloat(single_channel_value); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 2f85351d610..d1f9bbb391a 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7162,11 +7162,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most\n dimension of `input` is replaced with its 1D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fft\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most\n dimension of `input` is replaced with its 1D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fft\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 1-dimensional discrete Fourier Transform over the inner-most" - description: "dimension of `input`." + summary: "Fast Fourier transform." + description: "Computes the 1-dimensional discrete Fourier transform over the inner-most\ndimension of `input`." } op { name: "FFT2D" @@ -7177,11 +7177,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most 2\n dimensions of `input` are replaced with their 2D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fft2\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most 2\n dimensions of `input` are replaced with their 2D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fft2\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 2-dimensional discrete Fourier Transform over the inner-most" - description: "2 dimensions of `input`." + summary: "2D fast Fourier transform." + description: "Computes the 2-dimensional discrete Fourier transform over the inner-most\n2 dimensions of `input`." } op { name: "FFT3D" @@ -7192,11 +7192,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most 3\n dimensions of `input` are replaced with their 3D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fftn with 3 dimensions.\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most 3\n dimensions of `input` are replaced with their 3D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.fftn with 3 dimensions.\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 3-dimensional discrete Fourier Transform over the inner-most 3" - description: "dimensions of `input`." + summary: "3D fast Fourier transform." 
+ description: "Computes the 3-dimensional discrete Fourier transform over the inner-most 3\ndimensions of `input`." } op { name: "FIFOQueue" @@ -8749,11 +8749,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most\n dimension of `input` is replaced with its inverse 1D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifft\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most\n dimension of `input` is replaced with its inverse 1D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifft\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the inverse 1-dimensional discrete Fourier Transform over the inner-most" - description: "dimension of `input`." + summary: "Inverse fast Fourier transform." + description: "Computes the inverse 1-dimensional discrete Fourier transform over the\ninner-most dimension of `input`." } op { name: "IFFT2D" @@ -8764,11 +8764,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most 2\n dimensions of `input` are replaced with their inverse 2D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifft2\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most 2\n dimensions of `input` are replaced with their inverse 2D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifft2\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the inverse 2-dimensional discrete Fourier Transform over the inner-most" - description: "2 dimensions of `input`." + summary: "Inverse 2D fast Fourier transform." + description: "Computes the inverse 2-dimensional discrete Fourier transform over the\ninner-most 2 dimensions of `input`." } op { name: "IFFT3D" @@ -8779,11 +8779,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same shape as `input`. The inner-most 3\n dimensions of `input` are replaced with their inverse 3D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifftn with 3 dimensions.\n@end_compatibility" + description: "A complex64 tensor of the same shape as `input`. The inner-most 3\n dimensions of `input` are replaced with their inverse 3D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.ifftn with 3 dimensions.\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the inverse 3-dimensional discrete Fourier Transform over the inner-most" - description: "3 dimensions of `input`." + summary: "Inverse 3D fast Fourier transform." + description: "Computes the inverse 3-dimensional discrete Fourier transform over the\ninner-most 3 dimensions of `input`." } op { name: "IRFFT" @@ -8799,11 +8799,11 @@ op { } output_arg { name: "output" - description: "A float32 tensor of the same rank as `input`. The inner-most\n dimension of `input` is replaced with the `fft_length` samples of its inverse\n 1D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.irfft\n@end_compatibility" + description: "A float32 tensor of the same rank as `input`. 
The inner-most\n dimension of `input` is replaced with the `fft_length` samples of its inverse\n 1D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.irfft\n@end_compatibility" type: DT_FLOAT } - summary: "Compute the inverse 1-dimensional discrete Fourier Transform of a real-valued" - description: "signal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly." + summary: "Inverse real-valued fast Fourier transform." + description: "Computes the inverse 1-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly." } op { name: "IRFFT2D" @@ -8819,11 +8819,11 @@ op { } output_arg { name: "output" - description: "A float32 tensor of the same rank as `input`. The inner-most 2\n dimensions of `input` are replaced with the `fft_length` samples of their\n inverse 2D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.irfft2\n@end_compatibility" + description: "A float32 tensor of the same rank as `input`. The inner-most 2\n dimensions of `input` are replaced with the `fft_length` samples of their\n inverse 2D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.irfft2\n@end_compatibility" type: DT_FLOAT } - summary: "Compute the inverse 2-dimensional discrete Fourier Transform of a real-valued" - description: "signal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." + summary: "Inverse 2D real-valued fast Fourier transform." + description: "Computes the inverse 2-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." } op { name: "IRFFT3D" @@ -8839,11 +8839,11 @@ op { } output_arg { name: "output" - description: "A float32 tensor of the same rank as `input`. 
The inner-most 3\n dimensions of `input` are replaced with the `fft_length` samples of their\n inverse 3D real Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.irfftn with 3 dimensions.\n@end_compatibility" + description: "A float32 tensor of the same rank as `input`. The inner-most 3\n dimensions of `input` are replaced with the `fft_length` samples of their\n inverse 3D real Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.irfftn with 3 dimensions.\n@end_compatibility" type: DT_FLOAT } - summary: "Compute the inverse 3-dimensional discrete Fourier Transform of a real-valued" - description: "signal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." + summary: "Inverse 3D real-valued fast Fourier transform." + description: "Computes the inverse 3-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." } op { name: "Identity" @@ -14666,11 +14666,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same rank as `input`. The inner-most\n dimension of `input` is replaced with the `fft_length / 2 + 1` unique\n frequency components of its 1D Fourier Transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfft\n@end_compatibility" + description: "A complex64 tensor of the same rank as `input`. The inner-most\n dimension of `input` is replaced with the `fft_length / 2 + 1` unique\n frequency components of its 1D Fourier transform.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfft\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 1-dimensional discrete Fourier Transform of a real-valued signal" - description: "over the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms." + summary: "Real-valued fast Fourier transform." + description: "Computes the 1-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms." } op { name: "RFFT2D" @@ -14686,11 +14686,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same rank as `input`. The inner-most 2\n dimensions of `input` are replaced with their 2D Fourier Transform. 
The\n inner-most dimension contains `fft_length / 2 + 1` unique frequency\n components.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfft2\n@end_compatibility" + description: "A complex64 tensor of the same rank as `input`. The inner-most 2\n dimensions of `input` are replaced with their 2D Fourier transform. The\n inner-most dimension contains `fft_length / 2 + 1` unique frequency\n components.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfft2\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 2-dimensional discrete Fourier Transform of a real-valued signal" - description: "over the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." + summary: "2D real-valued fast Fourier transform." + description: "Computes the 2-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." } op { name: "RFFT3D" @@ -14706,11 +14706,11 @@ op { } output_arg { name: "output" - description: "A complex64 tensor of the same rank as `input`. The inner-most 3\n dimensions of `input` are replaced with the their 3D Fourier Transform. The\n inner-most dimension contains `fft_length / 2 + 1` unique frequency\n components.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfftn with 3 dimensions.\n@end_compatibility" + description: "A complex64 tensor of the same rank as `input`. The inner-most 3\n dimensions of `input` are replaced with the their 3D Fourier transform. The\n inner-most dimension contains `fft_length / 2 + 1` unique frequency\n components.\n\n@compatibility(numpy)\nEquivalent to np.fft.rfftn with 3 dimensions.\n@end_compatibility" type: DT_COMPLEX64 } - summary: "Compute the 3-dimensional discrete Fourier Transform of a real-valued signal" - description: "over the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." + summary: "3D real-valued fast Fourier transform." + description: "Computes the 3-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." } op { name: "RGBToHSV" diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc index 1a2b2f3dab8..09b460fd147 100644 --- a/tensorflow/core/ops/spectral_ops.cc +++ b/tensorflow/core/ops/spectral_ops.cc @@ -31,12 +31,14 @@ REGISTER_OP("FFT") return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }) .Doc(R"doc( -Compute the 1-dimensional discrete Fourier Transform over the inner-most +Fast Fourier transform. 
+ +Computes the 1-dimensional discrete Fourier transform over the inner-most dimension of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most - dimension of `input` is replaced with its 1D Fourier Transform. + dimension of `input` is replaced with its 1D Fourier transform. @compatibility(numpy) Equivalent to np.fft.fft @@ -50,12 +52,14 @@ REGISTER_OP("IFFT") return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }) .Doc(R"doc( -Compute the inverse 1-dimensional discrete Fourier Transform over the inner-most -dimension of `input`. +Inverse fast Fourier transform. + +Computes the inverse 1-dimensional discrete Fourier transform over the +inner-most dimension of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most - dimension of `input` is replaced with its inverse 1D Fourier Transform. + dimension of `input` is replaced with its inverse 1D Fourier transform. @compatibility(numpy) Equivalent to np.fft.ifft @@ -69,12 +73,14 @@ REGISTER_OP("FFT2D") return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); }) .Doc(R"doc( -Compute the 2-dimensional discrete Fourier Transform over the inner-most +2D fast Fourier transform. + +Computes the 2-dimensional discrete Fourier transform over the inner-most 2 dimensions of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most 2 - dimensions of `input` are replaced with their 2D Fourier Transform. + dimensions of `input` are replaced with their 2D Fourier transform. @compatibility(numpy) Equivalent to np.fft.fft2 @@ -88,12 +94,14 @@ REGISTER_OP("IFFT2D") return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); }) .Doc(R"doc( -Compute the inverse 2-dimensional discrete Fourier Transform over the inner-most -2 dimensions of `input`. +Inverse 2D fast Fourier transform. + +Computes the inverse 2-dimensional discrete Fourier transform over the +inner-most 2 dimensions of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most 2 - dimensions of `input` are replaced with their inverse 2D Fourier Transform. + dimensions of `input` are replaced with their inverse 2D Fourier transform. @compatibility(numpy) Equivalent to np.fft.ifft2 @@ -107,12 +115,14 @@ REGISTER_OP("FFT3D") return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }) .Doc(R"doc( -Compute the 3-dimensional discrete Fourier Transform over the inner-most 3 +3D fast Fourier transform. + +Computes the 3-dimensional discrete Fourier transform over the inner-most 3 dimensions of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most 3 - dimensions of `input` are replaced with their 3D Fourier Transform. + dimensions of `input` are replaced with their 3D Fourier transform. @compatibility(numpy) Equivalent to np.fft.fftn with 3 dimensions. @@ -126,12 +136,14 @@ REGISTER_OP("IFFT3D") return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }) .Doc(R"doc( -Compute the inverse 3-dimensional discrete Fourier Transform over the inner-most -3 dimensions of `input`. +Inverse 3D fast Fourier transform. + +Computes the inverse 3-dimensional discrete Fourier transform over the +inner-most 3 dimensions of `input`. input: A complex64 tensor. output: A complex64 tensor of the same shape as `input`. The inner-most 3 - dimensions of `input` are replaced with their inverse 3D Fourier Transform. 
+ dimensions of `input` are replaced with their inverse 3D Fourier transform. @compatibility(numpy) Equivalent to np.fft.ifftn with 3 dimensions. @@ -180,7 +192,9 @@ REGISTER_OP("RFFT") .Output("output: complex64") .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); }) .Doc(R"doc( -Compute the 1-dimensional discrete Fourier Transform of a real-valued signal +Real-valued fast Fourier transform. + +Computes the 1-dimensional discrete Fourier transform of a real-valued signal over the inner-most dimension of `input`. Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the @@ -191,7 +205,7 @@ input: A float32 tensor. fft_length: An int32 tensor of shape [1]. The FFT length. output: A complex64 tensor of the same rank as `input`. The inner-most dimension of `input` is replaced with the `fft_length / 2 + 1` unique - frequency components of its 1D Fourier Transform. + frequency components of its 1D Fourier transform. @compatibility(numpy) Equivalent to np.fft.rfft @@ -204,7 +218,9 @@ REGISTER_OP("IRFFT") .Output("output: float") .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); }) .Doc(R"doc( -Compute the inverse 1-dimensional discrete Fourier Transform of a real-valued +Inverse real-valued fast Fourier transform. + +Computes the inverse 1-dimensional discrete Fourier transform of a real-valued signal over the inner-most dimension of `input`. The inner-most dimension of `input` is assumed to be the result of `RFFT`: the @@ -218,7 +234,7 @@ input: A complex64 tensor. fft_length: An int32 tensor of shape [1]. The FFT length. output: A float32 tensor of the same rank as `input`. The inner-most dimension of `input` is replaced with the `fft_length` samples of its inverse - 1D Fourier Transform. + 1D Fourier transform. @compatibility(numpy) Equivalent to np.fft.irfft @@ -231,7 +247,9 @@ REGISTER_OP("RFFT2D") .Output("output: complex64") .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); }) .Doc(R"doc( -Compute the 2-dimensional discrete Fourier Transform of a real-valued signal +2D real-valued fast Fourier transform. + +Computes the 2-dimensional discrete Fourier transform of a real-valued signal over the inner-most 2 dimensions of `input`. Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the @@ -242,7 +260,7 @@ positive-frequency terms. input: A float32 tensor. fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. output: A complex64 tensor of the same rank as `input`. The inner-most 2 - dimensions of `input` are replaced with their 2D Fourier Transform. The + dimensions of `input` are replaced with their 2D Fourier transform. The inner-most dimension contains `fft_length / 2 + 1` unique frequency components. @@ -257,7 +275,9 @@ REGISTER_OP("IRFFT2D") .Output("output: float") .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); }) .Doc(R"doc( -Compute the inverse 2-dimensional discrete Fourier Transform of a real-valued +Inverse 2D real-valued fast Fourier transform. + +Computes the inverse 2-dimensional discrete Fourier transform of a real-valued signal over the inner-most 2 dimensions of `input`. The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: @@ -271,7 +291,7 @@ input: A complex64 tensor. fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. output: A float32 tensor of the same rank as `input`. 
The inner-most 2
  dimensions of `input` are replaced with the `fft_length` samples of their
-  inverse 2D Fourier Transform.
+  inverse 2D Fourier transform.

@compatibility(numpy)
Equivalent to np.fft.irfft2
@@ -284,7 +304,9 @@ REGISTER_OP("RFFT3D")
     .Output("output: complex64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); })
     .Doc(R"doc(
-Compute the 3-dimensional discrete Fourier Transform of a real-valued signal
+3D real-valued fast Fourier transform.
+
+Computes the 3-dimensional discrete Fourier transform of a real-valued signal
 over the inner-most 3 dimensions of `input`.

 Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
@@ -295,7 +317,7 @@ positive-frequency terms.
 input: A float32 tensor.
 fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
 output: A complex64 tensor of the same rank as `input`. The inner-most 3
-  dimensions of `input` are replaced with the their 3D Fourier Transform. The
+  dimensions of `input` are replaced with their 3D Fourier transform. The
   inner-most dimension contains `fft_length / 2 + 1` unique frequency
   components.
@@ -310,7 +332,9 @@ REGISTER_OP("IRFFT3D")
     .Output("output: float")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); })
     .Doc(R"doc(
-Compute the inverse 3-dimensional discrete Fourier Transform of a real-valued
+Inverse 3D real-valued fast Fourier transform.
+
+Computes the inverse 3-dimensional discrete Fourier transform of a real-valued
 signal over the inner-most 3 dimensions of `input`.

 The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:
@@ -324,7 +348,7 @@ input: A complex64 tensor.
 fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
 output: A float32 tensor of the same rank as `input`. The inner-most 3
   dimensions of `input` are replaced with the `fft_length` samples of their
-  inverse 3D real Fourier Transform.
+  inverse 3D real Fourier transform.

@compatibility(numpy)
Equivalent to np.irfftn with 3 dimensions.
diff --git a/tensorflow/core/platform/cpu_info.cc b/tensorflow/core/platform/cpu_info.cc
index e119ad5e2a2..451b7209bb7 100644
--- a/tensorflow/core/platform/cpu_info.cc
+++ b/tensorflow/core/platform/cpu_info.cc
@@ -69,6 +69,10 @@ int GetXCR0EAX() {
 // Structure for basic CPUID info
 class CPUIDInfo {
 public:
+  string vendor_str;
+  int family;
+  int model_num;
+
   CPUIDInfo()
       : have_adx_(0),
         have_aes_(0),
@@ -115,12 +119,21 @@ public:
     uint32 eax, ebx, ecx, edx;

+    // Get vendor string (issue CPUID with eax = 0)
+    GETCPUID(eax, ebx, ecx, edx, 0, 0);
+    cpuid->vendor_str.append(reinterpret_cast<char*>(&ebx), 4);
+    cpuid->vendor_str.append(reinterpret_cast<char*>(&edx), 4);
+    cpuid->vendor_str.append(reinterpret_cast<char*>(&ecx), 4);
+
     // To get general information and extended features we send eax = 1 and
     // ecx = 0 to cpuid.  The response is returned in eax, ebx, ecx and edx.
     // (See Intel 64 and IA-32 Architectures Software Developer's Manual
     // Volume 2A: Instruction Set Reference, A-M CPUID).
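For readers tracing the new CPUID plumbing: leaf 0 returns the 12-byte vendor string as the raw bytes of EBX, EDX, and ECX (in that order), and leaf 1 packs family and model into EAX. A minimal Python sketch of the same decoding, using made-up register values (the constants below are illustrative, not read from real hardware):

```python
import struct

# Hypothetical CPUID results, for illustration only.
# Leaf 0: the vendor string is the bytes of EBX, EDX, ECX concatenated.
ebx, edx, ecx = 0x756E6547, 0x49656E69, 0x6C65746E
vendor = b"".join(struct.pack("<I", reg) for reg in (ebx, edx, ecx))
print(vendor.decode("ascii"))  # GenuineIntel

# Leaf 1: EAX packs stepping, model, and family into its low bits. Like the
# C++ above, this reads only the base fields and ignores the extended
# family/model bits.
eax = 0x000306D4  # hypothetical EAX value
model_num = (eax >> 4) & 0xF
family = (eax >> 8) & 0xF
print(family, model_num)  # 6 13
```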
GETCPUID(eax, ebx, ecx, edx, 1, 0);

+    cpuid->model_num = static_cast<int>((eax >> 4) & 0xf);
+    cpuid->family = static_cast<int>((eax >> 8) & 0xf);
+
     cpuid->have_aes_ = (ecx >> 25) & 0x1;
     cpuid->have_cmov_ = (edx >> 15) & 0x1;
     cpuid->have_cmpxchg16b_ = (ecx >> 13) & 0x1;
@@ -302,5 +315,32 @@ bool TestCPUFeature(CPUFeature feature) {
 #endif
 }

+std::string CPUVendorIDString() {
+#ifdef PLATFORM_IS_X86
+  InitCPUIDInfo();
+  return cpuid->vendor_str;
+#else
+  return "";
+#endif
+}
+
+int CPUFamily() {
+#ifdef PLATFORM_IS_X86
+  InitCPUIDInfo();
+  return cpuid->family;
+#else
+  return 0;
+#endif
+}
+
+int CPUModelNum() {
+#ifdef PLATFORM_IS_X86
+  InitCPUIDInfo();
+  return cpuid->model_num;
+#else
+  return 0;
+#endif
+}
+
 }  // namespace port
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h
index f6eee478e8d..331f3e52516 100644
--- a/tensorflow/core/platform/cpu_info.h
+++ b/tensorflow/core/platform/cpu_info.h
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_PLATFORM_CPU_INFO_H_
 #define TENSORFLOW_PLATFORM_CPU_INFO_H_

+#include <string>
+
 #if defined(PLATFORM_WINDOWS)
 #include "tensorflow/core/platform/windows/cpu_info.h"
 #endif
@@ -92,6 +94,18 @@ enum CPUFeature {
 // Checks CPU registers to return hardware capabilities.
 bool TestCPUFeature(CPUFeature feature);

+// Returns CPU Vendor string (e.g., 'GenuineIntel', 'AuthenticAMD').
+std::string CPUVendorIDString();
+
+// Returns CPU family.
+int CPUFamily();
+
+// Returns CPU model number.
+int CPUModelNum();
+
+// Returns nominal core processor cycles per second of each processor.
+double NominalCPUFrequency();
+
 }  // namespace port
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index e2b9c586c86..66c4ff37b90 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -156,5 +156,10 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {

 string Demangle(const char* mangled) { return mangled; }

+double NominalCPUFrequency() {
+  // TODO(yuefengz): implement it for this platform.
+  return 1.0;
+}
+
 }  // namespace port
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index fc969c1a05d..b29f978f661 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -149,5 +149,10 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {

 string Demangle(const char* mangled) { return mangled; }

+double NominalCPUFrequency() {
+  // TODO(yuefengz): implement it for this platform.
+  return 1.0;
+}
+
 }  // namespace port
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 294419e907e..5b28aeb70ad 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -89,7 +89,6 @@ class CTCGreedyDecoder : public CTCDecoder {
       std::vector<int>& output_b = (*output)[0][b];

       int prev_class_ix = -1;
-      std::vector<int> transcription;
       (*scores)(b, 0) = 0;
       for (int t = 0; t < seq_len_b; ++t) {
         auto row = input[t].row(b);
@@ -98,7 +97,6 @@
         if (max_class_ix != blank_index_ &&
             !(merge_repeated_ && max_class_ix == prev_class_ix)) {
           output_b.push_back(max_class_ix);
-          transcription.push_back(max_class_ix);
         }
         prev_class_ix = max_class_ix;
       }
diff --git a/tensorflow/core/util/overflow.h b/tensorflow/core/util/overflow.h
index 13b1f305aa7..04be68a111e 100644
--- a/tensorflow/core/util/overflow.h
+++ b/tensorflow/core/util/overflow.h
@@ -31,9 +31,8 @@ inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
   const uint64 uy = y;
   const uint64 uxy = ux * uy;

-  // Check for overflow, using a cheap check if both inputs are small
-  static const uint64 kSqrtInt64Max = 3037000500;  // ceil(sqrt(2**63 - 1))
-  if (TF_PREDICT_FALSE(ux >= kSqrtInt64Max || uy >= kSqrtInt64Max)) {
+  // Check if we overflow uint64, using a cheap check if both inputs are small
+  if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) {
     // Ensure nonnegativity.  Note that negative numbers will appear "large"
     // to the unsigned comparisons above.
     CHECK(x >= 0 && y >= 0);
diff --git a/tensorflow/core/util/overflow_test.cc b/tensorflow/core/util/overflow_test.cc
index 627f77164e9..f93ba885e6d 100644
--- a/tensorflow/core/util/overflow_test.cc
+++ b/tensorflow/core/util/overflow_test.cc
@@ -30,8 +30,12 @@ TEST(OverflowTest, Nonnegative) {
     interesting.push_back(bit + 1);
     interesting.push_back(bit - 1);
   }
-  auto mid = static_cast<int64>(std::pow(2, 63.0 / 2));
-  for (int i = -5; i < 5; i++) interesting.push_back(mid + i);
+  for (const int64 mid : {static_cast<int64>(1) << 32,
+                          static_cast<int64>(std::pow(2, 63.0 / 2))}) {
+    for (int i = -5; i < 5; i++) {
+      interesting.push_back(mid + i);
+    }
+  }

   // Check all pairs
   for (auto x : interesting) {
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index fe89fe852e2..8c76f0f3c5a 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -177,7 +177,7 @@ inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
                                    gtl::ArraySlice<int64> spatial, int64 C) {
   gtl::InlinedVector dim_sizes(spatial.size() + 2);
   dim_sizes[GetTensorBatchDimIndex(dim_sizes.size(), format)] = N;
-  for (int dim = 0; dim < spatial.size(); dim++) {
+  for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
     dim_sizes[GetTensorSpatialDimIndex(dim_sizes.size(), format, dim)] =
         spatial[dim];
   }
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
index 3481a6aaa4d..1e5a9c57126 100644
--- a/tensorflow/core/util/util.cc
+++ b/tensorflow/core/util/util.cc
@@ -85,7 +85,7 @@ void MovingAverage::AddValue(double v) {

 static char hex_char[] = "0123456789abcdef";

-string PrintMemory(const char* ptr, int n) {
+string PrintMemory(const char* ptr, size_t n) {
   string ret;
   ret.resize(n * 3);
   for (int i = 0; i < n; ++i) {
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
index c142f4d0d26..4adf2f14dcc 100644
--- a/tensorflow/core/util/util.h
+++ b/tensorflow/core/util/util.h
@@ -49,7 +49,7 @@ class
MovingAverage { // Returns a string printing bytes in ptr[0..n). The output looks // like "00 01 ef cd cd ef". -string PrintMemory(const char* ptr, int n); +string PrintMemory(const char* ptr, size_t n); // Given a flattened index into a tensor, computes a string s so that // StrAppend("tensor", s) is a Python indexing expression. E.g., diff --git a/tensorflow/docs_src/api_guides/python/client.md b/tensorflow/docs_src/api_guides/python/client.md index f5bb256d870..97c19863600 100644 --- a/tensorflow/docs_src/api_guides/python/client.md +++ b/tensorflow/docs_src/api_guides/python/client.md @@ -3,7 +3,7 @@ This library contains classes for launching graphs and executing operations. -The @{$get_started} guide has +The @{$get_started/get_started} guide has examples of how a graph is launched in a @{tf.Session}. ## Session management diff --git a/tensorflow/docs_src/deploy/distributed.md b/tensorflow/docs_src/deploy/distributed.md index bcc5b92db88..cdfb4672fa0 100644 --- a/tensorflow/docs_src/deploy/distributed.md +++ b/tensorflow/docs_src/deploy/distributed.md @@ -2,7 +2,7 @@ This document shows how to create a cluster of TensorFlow servers, and how to distribute a computation graph across that cluster. We assume that you are -familiar with the @{$get_started$basic concepts} of +familiar with the @{$get_started/get_started$basic concepts} of writing TensorFlow programs. ## Hello distributed TensorFlow! diff --git a/tensorflow/docs_src/extend/architecture.md b/tensorflow/docs_src/extend/architecture.md index 085f74c0560..42721eb488c 100644 --- a/tensorflow/docs_src/extend/architecture.md +++ b/tensorflow/docs_src/extend/architecture.md @@ -7,7 +7,7 @@ learning models and system-level optimizations. This document describes the system architecture that makes possible this combination of scale and flexibility. It assumes that you have basic familiarity with TensorFlow programming concepts such as the computation graph, operations, -and sessions. See @{$get_started$Getting Started} +and sessions. See @{$get_started/get_started$Getting Started} for an introduction to these topics. Some familiarity with @{$distributed$distributed TensorFlow} will also be helpful. diff --git a/tensorflow/docs_src/get_started/embedding_viz.md b/tensorflow/docs_src/get_started/embedding_viz.md index 64042497035..f512d5d809b 100644 --- a/tensorflow/docs_src/get_started/embedding_viz.md +++ b/tensorflow/docs_src/get_started/embedding_viz.md @@ -39,7 +39,7 @@ labels/images to the data points. You can do this by generating a [metadata file](#metadata) containing the labels for each point and configuring the projector either by using our Python API, or manually constructing and saving a -[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto) +[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto) in the same directory as your checkpoint file. 
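The metadata file mentioned above is plain TSV. As a sketch of the manual route (the paths and labels here are hypothetical): a single-column file carries one label per line, in the same order as the rows of the embedding variable, with no header row.

```python
import os

LOG_DIR = "/tmp/projector_demo"   # hypothetical log directory
labels = ["cat", "dog", "truck"]  # one label per embedding row

if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)

# Single-column metadata has no header; a multi-column file (e.g. label plus
# image path) would need a tab-separated header row.
with open(os.path.join(LOG_DIR, "metadata.tsv"), "w") as f:
    for label in labels:
        f.write(label + "\n")
```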
## Setup
@@ -68,7 +68,7 @@ saver.save(session, os.path.join(LOG_DIR, "model.ckpt"), step)

If you have any metadata (labels, images) associated with your embedding, you
can tell TensorBoard about it either by directly storing a
-[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto)
+[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto)
in the LOG_DIR, or by using our Python API.

For instance, the following projector_config.pbtxt associates the
@@ -91,7 +91,7 @@ N = 10000 # Number of items (vocab size).
D = 200 # Dimensionality of the embedding.
embedding_var = tf.Variable(tf.random_normal([N,D]), name='word_embedding')

-# Format: tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto
+# Format: tensorflow/tensorboard/plugins/projector/projector_config.proto
config = projector.ProjectorConfig()

# You can add multiple embeddings. Here we add only one.
diff --git a/tensorflow/docs_src/get_started/mnist/beginners.md b/tensorflow/docs_src/get_started/mnist/beginners.md
index f6d6b230b39..b9a85f3f676 100644
--- a/tensorflow/docs_src/get_started/mnist/beginners.md
+++ b/tensorflow/docs_src/get_started/mnist/beginners.md
@@ -208,7 +208,8 @@ the \\(x\\)s, add a bias, and then apply softmax.

If we write that out as equations, we get:
- +[y1, y2, y3] = softmax(W11*x1 + W12*x2 + W13*x3 + b1,  W21*x1 + W22*x2 + W23*x3 + b2,  W31*x1 + W32*x2 + W33*x3 + b3)
We can "vectorize" this procedure, turning it into a matrix multiplication @@ -216,7 +217,8 @@ and vector addition. This is helpful for computational efficiency. (It's also a useful way to think.)
- +[y1, y2, y3] = softmax([[W11, W12, W13], [W21, W22, W23], [W31, W32, W33]]*[x1, x2, x3] + [b1, b2, b3])
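A quick NumPy check of the vectorized form, with made-up numbers whose shapes match the 3-class example above:

```python
import numpy as np

W = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9]])   # hypothetical weights
x = np.array([1.0, 2.0, 3.0])     # hypothetical input
b = np.array([0.1, 0.2, 0.3])     # hypothetical biases

z = W.dot(x) + b                   # one matrix multiply plus one vector add
y = np.exp(z) / np.exp(z).sum()    # softmax
print(y.sum())                     # 1.0 -- the outputs form a distribution
```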
More compactly, we can just write:

\\(y = \text{softmax}(Wx + b)\\)
diff --git a/tensorflow/docs_src/get_started/mnist/pros.md b/tensorflow/docs_src/get_started/mnist/pros.md
index 6f0ba492d30..5dbb00c0b52 100644
--- a/tensorflow/docs_src/get_started/mnist/pros.md
+++ b/tensorflow/docs_src/get_started/mnist/pros.md
@@ -65,12 +65,12 @@ programs is to first create a graph and then launch it in a session. Here we
instead use the convenient `InteractiveSession` class, which makes TensorFlow
more flexible about how you structure your code. It allows you to interleave
operations which build a
-@{$get_started#the_computational_graph$computation graph}
+@{$get_started/get_started#the_computational_graph$computation graph}
with ones that run the graph. This is particularly convenient when working in
interactive contexts like IPython. If you are not using an
`InteractiveSession`, then you should build the entire computation graph
before starting a session and
-@{$get_started#the_computational_graph$launching the graph}.
+@{$get_started/get_started#the_computational_graph$launching the graph}.

```python
import tensorflow as tf
@@ -95,8 +95,8 @@ similar to that used in Theano or Torch.

The role of the Python code is therefore to build this external computation
graph, and to dictate which parts of the computation graph should be run. See
-the @{$get_started#the_computational_graph$Computation Graph}
-section of @{$get_started} for more detail.
+the @{$get_started/get_started#the_computational_graph$Computation Graph}
+section of @{$get_started/get_started} for more detail.

## Build a Softmax Regression Model
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 56c4b8d2c58..111b046689e 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -105,7 +105,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
     mvn -q compile exec:java

-The preceeding command should output Hello from version. If it
+The preceding command should output Hello from version. If it
does, you've successfully set up TensorFlow for Java and are ready to use it
in Maven projects. If not, check
[Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow)
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 80331e7ea8d..4a537135b6e 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -507,7 +507,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started$Getting Started with TensorFlow}. +If you are new to TensorFlow, see @{$get_started/get_started$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 592036d1eb3..68ad3ceba9c 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -449,7 +449,7 @@ writing TensorFlow programs:
Hello, TensorFlow!
If you are new to TensorFlow, see -@{$get_started$Getting Started with TensorFlow}. +@{$get_started/get_started$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 3597a68b83b..25e428f574e 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -353,7 +353,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started$Getting Started with +If you are new to TensorFlow, see @{$get_started/get_started$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 3e5451ff4df..7a6f5793419 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -145,7 +145,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started$Getting Started with +If you are new to TensorFlow, see @{$get_started/get_started$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common @@ -193,5 +193,5 @@ ImportError: cannot import name 'descriptor'
No module named "pywrap_tensorflow"
- +
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index 7ecddc548fe..6f442e6e0c4 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -418,8 +418,9 @@ that the graph contains, this kind of disk space issue can happen. There are three possible workarounds or solutions: 1. The constructors of `LocalCLIDebugWrapperSession` and `LocalCLIDebugHook` - provide a keyword argument, `dump_root`, with which you can specify the path + provide a keyword argument, `dump_root`, with which you can specify the path to which **tfdbg** dumps the debug data. For example: + ``` python # For LocalCLIDebugWrapperSession sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space") @@ -432,6 +433,7 @@ There are three possible workarounds or solutions: 2. Reduce the batch size used during the runs. 3. Use the filtering options of **tfdbg**'s `run` command to watch only specific nodes in the graph. For example: + ``` tfdbg> run --node_name_filter .*hidden.* tfdbg> run --op_type_filter Variable.* diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index de014abafa9..bbf687e3229 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -9075,15 +9075,16 @@ func Rint(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Compute the 1-dimensional discrete Fourier Transform over the inner-most +// Fast Fourier transform. // +// Computes the 1-dimensional discrete Fourier transform over the inner-most // dimension of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier Transform. +// dimension of `input` is replaced with its 1D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.fft @@ -9714,15 +9715,16 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms return scope.AddOperation(opspec) } -// Compute the inverse 3-dimensional discrete Fourier Transform over the inner-most +// Inverse 3D fast Fourier transform. // -// 3 dimensions of `input`. +// Computes the inverse 3-dimensional discrete Fourier transform over the +// inner-most 3 dimensions of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their inverse 3D Fourier Transform. +// dimensions of `input` are replaced with their inverse 3D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.ifftn with 3 dimensions. @@ -9991,15 +9993,16 @@ func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, b return scope.AddOperation(opspec) } -// Compute the 3-dimensional discrete Fourier Transform over the inner-most 3 +// 3D fast Fourier transform. // +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 // dimensions of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier Transform. +// dimensions of `input` are replaced with their 3D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.fftn with 3 dimensions. 
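The regenerated wrappers around this point all advertise NumPy equivalences. A minimal sketch verifying two of them from the Python API (assuming a TF 1.x session; the tolerance accounts for complex64 rounding):

```python
import numpy as np
import tensorflow as tf

signal = np.random.rand(8).astype(np.complex64)

with tf.Session() as sess:
    tf_out = sess.run(tf.fft(tf.constant(signal)))
print(np.allclose(tf_out, np.fft.fft(signal), atol=1e-3))  # True

# Hermitian symmetry behind the RFFT docs: a length-8 real signal has only
# 8 / 2 + 1 = 5 unique frequency components, which is all np.fft.rfft keeps.
real_signal = np.random.rand(8).astype(np.float32)
print(len(np.fft.rfft(real_signal)))                                       # 5
print(np.allclose(np.fft.fft(real_signal)[:5], np.fft.rfft(real_signal)))  # True
```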
@@ -10184,15 +10187,16 @@ func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf return scope.AddOperation(opspec) } -// Compute the 2-dimensional discrete Fourier Transform over the inner-most +// 2D fast Fourier transform. // +// Computes the 2-dimensional discrete Fourier transform over the inner-most // 2 dimensions of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier Transform. +// dimensions of `input` are replaced with their 2D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.fft2 @@ -10244,15 +10248,16 @@ func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) { return op.Output(0) } -// Compute the inverse 2-dimensional discrete Fourier Transform over the inner-most +// Inverse 2D fast Fourier transform. // -// 2 dimensions of `input`. +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier Transform. +// dimensions of `input` are replaced with their inverse 2D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.ifft2 @@ -11256,8 +11261,9 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { return op.Output(0) } -// Compute the inverse 2-dimensional discrete Fourier Transform of a real-valued +// Inverse 2D real-valued fast Fourier transform. // +// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued // signal over the inner-most 2 dimensions of `input`. // // The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: @@ -11273,7 +11279,7 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { // // Returns A float32 tensor of the same rank as `input`. The inner-most 2 // dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 2D Fourier Transform. +// inverse 2D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.irfft2 @@ -14173,15 +14179,16 @@ func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { return op.Output(0) } -// Compute the inverse 1-dimensional discrete Fourier Transform over the inner-most +// Inverse fast Fourier transform. // -// dimension of `input`. +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. // // Arguments: // input: A complex64 tensor. // // Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier Transform. +// dimension of `input` is replaced with its inverse 1D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.ifft @@ -14200,8 +14207,9 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Compute the inverse 1-dimensional discrete Fourier Transform of a real-valued +// Inverse real-valued fast Fourier transform. // +// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued // signal over the inner-most dimension of `input`. 
// // The inner-most dimension of `input` is assumed to be the result of `RFFT`: the @@ -14217,7 +14225,7 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) { // // Returns A float32 tensor of the same rank as `input`. The inner-most // dimension of `input` is replaced with the `fft_length` samples of its inverse -// 1D Fourier Transform. +// 1D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.irfft @@ -14361,8 +14369,9 @@ func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o * return scope.AddOperation(opspec) } -// Compute the 1-dimensional discrete Fourier Transform of a real-valued signal +// Real-valued fast Fourier transform. // +// Computes the 1-dimensional discrete Fourier transform of a real-valued signal // over the inner-most dimension of `input`. // // Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the @@ -14375,7 +14384,7 @@ func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o * // // Returns A complex64 tensor of the same rank as `input`. The inner-most // dimension of `input` is replaced with the `fft_length / 2 + 1` unique -// frequency components of its 1D Fourier Transform. +// frequency components of its 1D Fourier transform. // // @compatibility(numpy) // Equivalent to np.fft.rfft @@ -14951,8 +14960,9 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output) (indices tf return op.Output(0), op.Output(1), op.Output(2) } -// Compute the inverse 3-dimensional discrete Fourier Transform of a real-valued +// Inverse 3D real-valued fast Fourier transform. // +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued // signal over the inner-most 3 dimensions of `input`. // // The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: @@ -14968,7 +14978,7 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output) (indices tf // // Returns A float32 tensor of the same rank as `input`. The inner-most 3 // dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier Transform. +// inverse 3D real Fourier transform. // // @compatibility(numpy) // Equivalent to np.irfftn with 3 dimensions. @@ -16420,8 +16430,9 @@ func Erfc(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Compute the 2-dimensional discrete Fourier Transform of a real-valued signal +// 2D real-valued fast Fourier transform. // +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal // over the inner-most 2 dimensions of `input`. // // Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the @@ -16434,7 +16445,7 @@ func Erfc(scope *Scope, x tf.Output) (y tf.Output) { // fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. // // Returns A complex64 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier Transform. The +// dimensions of `input` are replaced with their 2D Fourier transform. The // inner-most dimension contains `fft_length / 2 + 1` unique frequency // components. // @@ -20064,8 +20075,9 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// Compute the 3-dimensional discrete Fourier Transform of a real-valued signal +// 3D real-valued fast Fourier transform. 
//
+// Computes the 3-dimensional discrete Fourier transform of a real-valued signal
// over the inner-most 3 dimensions of `input`.
//
// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
@@ -20078,7 +20090,7 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow
// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
//
// Returns A complex64 tensor of the same rank as `input`. The inner-most 3
-// dimensions of `input` are replaced with the their 3D Fourier Transform. The
+// dimensions of `input` are replaced with their 3D Fourier transform. The
// inner-most dimension contains `fft_length / 2 + 1` unique frequency
// components.
//
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index 1a9c99bd759..42cd24a7df4 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -24,7 +24,7 @@ your project's `pom.xml`:
   <dependency>
     <groupId>org.tensorflow</groupId>
     <artifactId>tensorflow</artifactId>
-    <version>1.1.0-rc0-windows-fix</version>
+    <version>1.1.0-rc1</version>
   </dependency>
 ```
@@ -50,7 +50,7 @@ That's all. As an example, to create a Maven project for the
   <dependency>
     <groupId>org.tensorflow</groupId>
     <artifactId>tensorflow</artifactId>
-    <version>1.1.0-rc0-windows-fix</version>
+    <version>1.1.0-rc1</version>
   </dependency>
@@ -75,22 +75,22 @@ That's all. As an example, to create a Maven project for the

This section describes how to use TensorFlow armed with just a JDK installation.

1. Download the Java archive (JAR):
-   [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc0.jar)
+   [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar)
    (optionally, the Java sources:
-   [libtensorflow-src.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-src-1.1.0-rc0.jar)).
+   [libtensorflow-src.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-src-1.1.0-rc1.jar)).

2. Download the native library. GPU-enabled versions require CUDA 8 and cuDNN
   5.1. For other versions, the native library will need to be built from
   source (see below).

   - Linux:
-    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.1.0-rc0.tar.gz),
-    [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.1.0-rc0.tar.gz)
+    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.1.0-rc1.tar.gz),
+    [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.1.0-rc1.tar.gz)
   - OS X:
-    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.1.0-rc0.tar.gz),
-    [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-darwin-x86_64-1.1.0-rc0.tar.gz)
+    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.1.0-rc1.tar.gz),
+    [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-darwin-x86_64-1.1.0-rc1.tar.gz)
   - Windows:
-    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc0.zip)
+    [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc1.zip)

   The following shell snippet downloads and extracts the native library on
@@ -101,7 +101,7 @@ This section describes how to use TensorFlow armed with just a JDK installation.
     OS=$(uname -s | tr '[:upper:]' '[:lower:]')
     mkdir -p ./jni
     curl -L \
-      "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc0.tar.gz" |
+      "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
     tar -xz -C ./jni
     ```
@@ -121,7 +121,7 @@ This section describes how to use TensorFlow armed with just a JDK installation.
   then it should be compiled with:

   ```sh
-  javac -cp libtensorflow-1.1.0-rc0.jar MyClass.java
+  javac -cp libtensorflow-1.1.0-rc1.jar MyClass.java
   ```

   For a more sophisticated example, see
@@ -130,7 +130,7 @@ This section describes how to use TensorFlow armed with just a JDK installation.

   ```sh
   javac \
-    -cp libtensorflow-1.1.0-rc0.jar \
+    -cp libtensorflow-1.1.0-rc1.jar \
     ./src/main/java/org/tensorflow/examples/LabelImage.java
   ```

@@ -138,7 +138,7 @@ This section describes how to use TensorFlow armed with just a JDK installation.
   library path during execution. For example:

   ```sh
-  java -cp libtensorflow-1.1.0-rc0.jar:. -Djava.library.path=./jni MyClass
+  java -cp libtensorflow-1.1.0-rc1.jar:. -Djava.library.path=./jni MyClass
   ```

   or for the `LabelImage` example:
@@ -146,7 +146,7 @@ This section describes how to use TensorFlow armed with just a JDK installation.
   ```sh
   java \
     -Djava.library.path=./jni \
-    -cp libtensorflow-1.1.0-rc0.jar:./src/main/java \
+    -cp libtensorflow-1.1.0-rc1.jar:./src/main/java \
     org.tensorflow.examples.LabelImage
   ```
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index d8d6a50da77..ec044a3ff68 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.1.0-rc0</version>
+    <version>1.1.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index a675859638b..8a243fe8348 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.1.0-rc0</version>
+    <version>1.1.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 35f9c2ecc34..4939ed0b630 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>org.tensorflow</groupId>
   <artifactId>parentpom</artifactId>
-  <version>1.1.0-rc0</version>
+  <version>1.1.0-rc1</version>
   <packaging>pom</packaging>
   <url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index d54face7c49..4b17d5d25ac 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.1.0-rc0</version>
+    <version>1.1.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9db763ce78f..4f7d2590456 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3307,39 +3307,6 @@ py_test(
     ],
 )

-py_library(
-    name = "docs",
-    srcs = ["framework/docs.py"],
-    srcs_version = "PY2AND3",
-)
-
-py_library(
-    name = "gen_docs_combined_lib",
-    srcs = ["framework/gen_docs_combined.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":docs",
-        "//tensorflow:tensorflow_py",
-        "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
-        "//tensorflow/python/debug:debug_py",
-    ],
-)
-
-py_binary(
-    name = "gen_docs_combined",
-    srcs = ["framework/gen_docs_combined.py"],
-    main = "framework/gen_docs_combined.py",
-    srcs_version = "PY2AND3",
-    deps = [
-        ":client",
-        ":docs",
-        ":framework",
-
":framework_for_generated_wrappers", - "//tensorflow:tensorflow_py", - "//tensorflow/python/debug:debug_py", - ], -) - # ----------------------------------------------------------------------------- # Quantization diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e53a046a34d..13c27366090 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1404,6 +1404,14 @@ class SessionTest(test_util.TensorFlowTestCase): r2 = sess.partial_run(h, [b, c]) self.assertEqual(r1, r2) + def runTestPartialRunMissingPlaceholderFeedException(self, sess): + x = array_ops.placeholder(dtypes.float32, shape=()) + fetches = [x * 2, x * 3] + handle = sess.partial_run_setup(fetches=fetches, feeds=[]) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'You must feed a value for placeholder'): + sess.partial_run(handle, fetches[0]) + def testPartialRunDirect(self): self.runTestPartialRun(session.Session()) @@ -1419,6 +1427,9 @@ class SessionTest(test_util.TensorFlowTestCase): def testRunAndPartialRunDirect(self): self.runTestRunAndPartialRun(session.Session()) + def testPartialRunMissingPlaceholderFeedExceptionDirect(self): + self.runTestPartialRunMissingPlaceholderFeedException(session.Session()) + def testPartialRunDist(self): server = server_lib.Server.create_local_server() self.runTestPartialRun(session.Session(server.target)) @@ -1439,6 +1450,11 @@ class SessionTest(test_util.TensorFlowTestCase): server = server_lib.Server.create_local_server() self.runTestRunAndPartialRun(session.Session(server.target)) + def testPartialRunMissingPlaceholderFeedExceptionDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunMissingPlaceholderFeedException( + session.Session(server.target)) + def testFeedDictKeyException(self): with session.Session() as sess: a = constant_op.constant(1.0, dtypes.float32, name='a') diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 511ddb16736..5acc1962cb9 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -484,9 +484,19 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): sess.run(variables.global_variables_initializer()) run_options = config_pb2.RunOptions(output_partition_graphs=True) - debug_utils.watch_graph(run_options, - sess.graph, - debug_urls=self._debug_urls()) + if any(url.startswith("grpc://") for url in self._debug_urls()): + debug_utils.watch_graph_with_blacklists( + run_options, + sess.graph, + node_name_regex_blacklist="(.*rnn/while/.*|.*TensorArray.*)", + debug_urls=self._debug_urls()) + # b/36870549: Nodes with these name patterns need to be excluded from + # tfdbg in order to prevent MSAN warnings of uninitialized Tensors + # under the grpc:// debug URL scheme. 
+ else: + debug_utils.watch_graph( + run_options, sess.graph, debug_urls=self._debug_urls()) + run_metadata = config_pb2.RunMetadata() sess.run(train_op, feed_dict={concat_inputs: input_values}, options=run_options, run_metadata=run_metadata) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index d83cd8b3702..c20a24b2ee3 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -198,10 +198,10 @@ class Estimator(object): if (steps is not None) and (max_steps is not None): raise ValueError('Can not provide both steps and max_steps.') if steps is not None and steps <= 0: - raise ValueError('Must specify steps >= 0, given: {}'.format(steps)) + raise ValueError('Must specify steps > 0, given: {}'.format(steps)) if max_steps is not None and max_steps <= 0: raise ValueError( - 'Must specify max_steps >= 0, given: {}'.format(max_steps)) + 'Must specify max_steps > 0, given: {}'.format(max_steps)) if max_steps is not None: start_step = _load_global_step_from_checkpoint_dir(self._model_dir) @@ -256,7 +256,7 @@ class Estimator(object): hooks = _check_hooks_type(hooks) if steps is not None: if steps <= 0: - raise ValueError('Must specify steps >= 0, given: {}'.format(steps)) + raise ValueError('Must specify steps > 0, given: {}'.format(steps)) hooks.append(evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access num_evals=steps)) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index cac0b55bd38..889ba3cf38e 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -239,25 +239,25 @@ class EstimatorTrainTest(test.TestCase): def test_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.train(dummy_input_fn, steps=0) def test_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.train(dummy_input_fn, steps=-1) def test_max_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify max_steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'): est.train(dummy_input_fn, max_steps=0) def test_max_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify max_steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'): est.train(dummy_input_fn, max_steps=-1) def test_scaffold_is_used(self): @@ -475,14 +475,14 @@ class EstimatorEvaluateTest(test.TestCase): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) est.train(dummy_input_fn, steps=5) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.evaluate(dummy_input_fn, steps=0) def test_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) est.train(dummy_input_fn, steps=5) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with 
self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.evaluate(dummy_input_fn, steps=-1) def test_global_step_metric_raises_error(self): diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py deleted file mode 100644 index 4ae0046117b..00000000000 --- a/tensorflow/python/framework/docs.py +++ /dev/null @@ -1,647 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Updates generated docs from Python doc comments. - -Updates the documentation files. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import inspect -import os -import re - - -_arg_re = re.compile(" *([*]{0,2}[a-zA-Z][a-zA-Z0-9_]*):") -_section_re = re.compile("([A-Z][a-zA-Z ]*):$") -_always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]") -_anchor_re = re.compile(r"^[\w.]+$") -_member_mark = "@@" -_indiv_dir = "functions_and_classes" -_num_subdirs = 10 -_subdir_prefix = "shard" - - -class Document(object): - """Base class for an automatically generated document.""" - - def write_markdown_to_file(self, f): - """Writes a Markdown-formatted version of this document to file `f`. - - Args: - f: The output file. - """ - raise NotImplementedError("Document.WriteToFile") - - -class Index(Document): - """An automatically generated index for a collection of documents.""" - - def __init__(self, module_to_name, members, filename_to_library_map, - path_prefix): - """Creates a new Index. - - Args: - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - filename_to_library_map: A list of (filename, Library) pairs. The order - corresponds to the order in which the libraries appear in the index. - path_prefix: Prefix to add to links in the index. - """ - self._module_to_name = module_to_name - self._members = members - self._filename_to_library_map = filename_to_library_map - self._path_prefix = path_prefix - - def write_markdown_to_file(self, f): - """Writes this index to file `f`. - - The output is formatted as an unordered list. Each list element - contains the title of the library, followed by a list of symbols - in that library hyperlinked to the corresponding anchor in that - library. - - Args: - f: The output file. - """ - print("", file=f) - print("", file=f) - print("# TensorFlow Python reference documentation", file=f) - print("", file=f) - fullname_f = lambda name: self._members[name][0] - anchor_f = lambda name: get_anchor(self._module_to_name, fullname_f(name)) - - for filename, library in self._filename_to_library_map: - sorted_names = sorted(library.mentioned, key=lambda x: (str.lower(x), x)) - member_names = [n for n in sorted_names if n in self._members] - # TODO(wicke): This is a hack that should be removed as soon as the - # website code allows it. 
- full_filename = self._path_prefix + filename - links = ["[`%s`](%s#%s)" % (name, full_filename, anchor_f(name)) - for name in member_names] - if links: - print("* **[%s](%s)**:" % (library.title, full_filename), file=f) - for link in links: - print(" * %s" % link, file=f) - print("", file=f) - - -def collect_members(module_to_name, exclude=()): - """Collect all symbols from a list of modules. - - Args: - module_to_name: Dictionary mapping modules to short names. - exclude: Set of fully qualified names to exclude. - - Returns: - Dictionary mapping name to (fullname, member) pairs. - - Raises: - RuntimeError: if we can not resolve a name collision. - """ - members = {} - for module, module_name in module_to_name.items(): - all_names = getattr(module, "__all__", None) - for name, member in inspect.getmembers(module): - if ((inspect.isfunction(member) - or inspect.isclass(member) - or isinstance(member, functools.partial)) - and not _always_drop_symbol_re.match(name) and - (all_names is None or name in all_names)): - fullname = "%s.%s" % (module_name, name) - if fullname in exclude: - continue - if name in members: - other_fullname, other_member = members[name] - if member is not other_member: - raise RuntimeError("Short name collision between %s and %s" % - (fullname, other_fullname)) - if len(fullname) == len(other_fullname): - raise RuntimeError("Can't decide whether to use %s or %s for %s: " - "both full names have length %d" % - (fullname, other_fullname, name, len(fullname))) - if len(fullname) > len(other_fullname): - continue # Use the shorter full name - members[name] = fullname, member - return members - - -def get_anchor(module_to_name, fullname): - """Turn a full member name into an anchor. - - Args: - module_to_name: Dictionary mapping modules to short names. - fullname: Fully qualified name of symbol. - - Returns: - HTML anchor string. The longest module name prefix of fullname is - removed to make the anchor. - - Raises: - ValueError: If fullname uses characters invalid in an anchor. - """ - if not _anchor_re.match(fullname): - raise ValueError("'%s' is not a valid anchor" % fullname) - anchor = fullname - for module_name in module_to_name.values(): - if fullname.startswith(module_name + "."): - rest = fullname[len(module_name)+1:] - # Use this prefix iff it is longer than any found before - if len(anchor) > len(rest): - anchor = rest - return anchor - - -def _stable_hash(s): - """A simple string hash that won't change from run to run.""" - ret = 0 - for c in s: - ret = ret * 97 + ord(c) - return ret - - -class Library(Document): - """An automatically generated document for a set of functions and classes.""" - - def __init__(self, - title, - module, - module_to_name, - members, - documented, - exclude_symbols=(), - prefix=None): - """Creates a new Library. - - Args: - title: A human-readable title for the library. - module: Module to pull high level docstring from (for table of contents, - list of Ops to document, etc.). - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - documented: Set of documented names to update. - exclude_symbols: A list of specific symbols to exclude. - prefix: A string to include at the beginning of the page. 
- """ - self._title = title - self._module = module - self._module_to_name = module_to_name - self._members = dict(members) # Copy since we mutate it below - self._exclude_symbols = frozenset(exclude_symbols) - documented.update(exclude_symbols) - self._documented = documented - self._mentioned = set() - self._prefix = prefix or "" - - @property - def title(self): - """The human-readable title for this library.""" - return self._title - - @property - def mentioned(self): - """Set of names mentioned in this library.""" - return self._mentioned - - @property - def exclude_symbols(self): - """Set of excluded symbols.""" - return self._exclude_symbols - - def _should_include_member(self, name): - """Returns True if this member should be included in the document.""" - # __x__ should be documented always - name_is_operator = name.startswith("__") and name.endswith("__") - name_is_private = name.startswith("_") and not name_is_operator - name_is_excluded = name in self._exclude_symbols - return not (name_is_private or name_is_excluded) - - def get_imported_modules(self, module): - """Returns the list of modules imported from `module`.""" - for name, member in inspect.getmembers(module): - if inspect.ismodule(member): - yield name, member - - def get_class_members(self, cls_name, cls): - """Returns the list of class members to document in `cls`. - - This function filters the class member to ONLY return those - defined by the class. It drops the inherited ones. - - Args: - cls_name: Qualified name of `cls`. - cls: An inspect object of type 'class'. - - Yields: - name, member tuples. - """ - for name, member in inspect.getmembers(cls): - # Only show methods and properties presently. In Python 3, - # methods register as isfunction. - is_method = (inspect.ismethod(member) or inspect.isfunction(member) - or isinstance(member, functools.partial)) - if not (is_method or isinstance(member, property)): - continue - if self._should_include_member(name): - yield name, ("%s.%s" % (cls_name, name), member) - - def shard_dir(self, name): - """Returns the path of the doc subdirectory for member `name`. - - When generating individual files for each function and class, we shard - the files across several directories to avoid hitting the limit for - files per directory. This function determines the subdirectory for - a member based on a stable hash of its name. - - Args: - name: string. The name of a function or class. - - Returns: - The path to a subdirectory of the api docs directory. - """ - index = _stable_hash(name) % _num_subdirs - return os.path.join(self.functions_and_classes_dir, - _subdir_prefix + str(index)) - - def set_functions_and_classes_dir(self, dirname): - """Sets the name of the directory for function and class markdown files. - - Args: - dirname: string. The name of the directory in which to store function - and class markdown files. - """ - self.functions_and_classes_dir = dirname - - def _generate_signature_for_function(self, func): - """Given a function, returns a string representing its args.""" - args_list = [] - if isinstance(func, functools.partial): - argspec = inspect.getargspec(func.func) - # Remove the args from the original function that have been used up. 
- first_default_arg = ( - len(argspec.args or []) - len(argspec.defaults or [])) - partial_args = len(func.args) - if argspec.args: - argspec_args = list(argspec.args[partial_args:]) - else: - argspec_args = [] - if argspec.defaults: - argspec_defaults = list(argspec.defaults[ - max(0, partial_args-first_default_arg):]) - else: - argspec_defaults = [] - first_default_arg = max(0, first_default_arg - partial_args) - for kwarg in func.keywords: - if kwarg in argspec_args: - i = argspec_args.index(kwarg) - argspec_args.pop(i) - if i >= first_default_arg: - argspec_defaults.pop(i-first_default_arg) - else: - first_default_arg -= 1 - argspec_varargs = None - argspec_keywords = None - - else: - argspec = inspect.getargspec(func) - argspec_args = argspec.args - argspec_defaults = argspec.defaults - argspec_varargs = argspec.varargs - argspec_keywords = argspec.keywords - - first_arg_with_default = ( - len(argspec_args or []) - len(argspec_defaults or [])) - for arg in argspec_args[:first_arg_with_default]: - if arg == "self": - # Python documentation typically skips `self` when printing method - # signatures. - continue - args_list.append(arg) - - # TODO(mrry): This is a workaround for documenting signature of - # functions that have the @contextlib.contextmanager decorator. - # TODO(aselle): This workaround is brittle on TestCase.__call__ - # so we need to wrap this in a try/catch - # We should do something better. - if argspec_varargs == "args" and argspec_keywords == "kwds": - try: - original_func = func.__closure__[0].cell_contents - return self._generate_signature_for_function(original_func) - except TypeError: - pass - - if argspec_defaults: - for arg, default in zip( - argspec_args[first_arg_with_default:], argspec_defaults): - if callable(default): - if hasattr(default, "__name__"): - args_list.append("%s=%s" % (arg, default.__name__)) - else: - # A callable may be a class instance. - # TODO(fchollet): handle case with non-default constructor - # arguments (currently not present in the TF codebase). - args_list.append("%s=%s()" % (arg, default.__class__.__name__)) - else: - args_list.append("%s=%r" % (arg, default)) - if argspec_varargs: - args_list.append("*" + argspec_varargs) - if argspec_keywords: - args_list.append("**" + argspec_keywords) - return "(" + ", ".join(args_list) + ")" - - def _remove_docstring_indent(self, docstring): - """Remove indenting. - - We follow Python's convention and remove the minimum indent of the lines - after the first, see: - https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation - preserving relative indentation. - - Args: - docstring: A docstring. - - Returns: - A list of strings, one per line, with the minimum indent stripped. - """ - docstring = docstring or "" - lines = docstring.strip().split("\n") - - min_indent = len(docstring) - for l in lines[1:]: - l = l.rstrip() - if l: - i = 0 - while i < len(l) and l[i] == " ": - i += 1 - if i < min_indent: min_indent = i - for i in range(1, len(lines)): - l = lines[i].rstrip() - if len(l) >= min_indent: - l = l[min_indent:] - lines[i] = l - return lines - - def _print_formatted_docstring(self, docstring, f): - """Formats the given `docstring` as Markdown and prints it to `f`.""" - lines = self._remove_docstring_indent(docstring) - - # Output the lines, identifying "Args" and other section blocks. 
- i = 0 - - def _at_start_of_section(): - """Returns the header if lines[i] is at start of a docstring section.""" - l = lines[i] - match = _section_re.match(l) - if match and i + 1 < len( - lines) and lines[i + 1].startswith(" "): - return match.group(1) - else: - return None - - while i < len(lines): - l = lines[i] - - section_header = _at_start_of_section() - if section_header: - if i == 0 or lines[i-1]: - print("", file=f) - # Use at least H4 to keep these out of the TOC. - print("##### " + section_header + ":", file=f) - print("", file=f) - i += 1 - outputting_list = False - while i < len(lines): - l = lines[i] - # A new section header terminates the section. - if _at_start_of_section(): - break - match = _arg_re.match(l) - if match: - if not outputting_list: - # We need to start a list. In Markdown, a blank line needs to - # precede a list. - print("", file=f) - outputting_list = True - suffix = l[len(match.group()):].lstrip() - print("* `" + match.group(1) + "`: " + suffix, file=f) - else: - # For lines that don't start with _arg_re, continue the list if it - # has enough indentation. - outputting_list &= l.startswith(" ") - print(l, file=f) - i += 1 - else: - print(l, file=f) - i += 1 - - def _print_function(self, f, prefix, fullname, func): - """Prints the given function to `f`.""" - heading = prefix + " `" + fullname - if not isinstance(func, property): - heading += self._generate_signature_for_function(func) - heading += "` {#%s}" % get_anchor(self._module_to_name, fullname) - print(heading, file=f) - print("", file=f) - self._print_formatted_docstring(inspect.getdoc(func), f) - print("", file=f) - - def _write_member_markdown_to_file(self, f, prefix, name, member): - """Print `member` to `f`.""" - if (inspect.isfunction(member) or inspect.ismethod(member) - or (isinstance(member, functools.partial) - and inspect.isfunction(member.func)) - or isinstance(member, property)): - print("- - -", file=f) - print("", file=f) - self._print_function(f, prefix, name, member) - print("", file=f) - - # Write an individual file for each function. - if inspect.isfunction(member): - indivf = open( - os.path.join(self.shard_dir(name), name + ".md"), "w+") - self._print_function(indivf, prefix, name, member) - elif (inspect.isclass(member) - or (isinstance(member, functools.partial) - and inspect.isclass(member.func))): - print("- - -", file=f) - print("", file=f) - print("%s `class %s` {#%s}" % (prefix, name, - get_anchor(self._module_to_name, name)), - file=f) - print("", file=f) - self._write_class_markdown_to_file(f, name, member) - print("", file=f) - - # Write an individual file for each class. - indivf = open( - os.path.join(self.shard_dir(name), name + ".md"), "w+") - self._write_class_markdown_to_file(indivf, name, member) - else: - raise RuntimeError("Member %s has unknown type %s" % (name, type(member))) - - def _write_docstring_markdown_to_file(self, f, prefix, docstring, members, - imports): - for l in self._remove_docstring_indent(docstring): - if l.startswith(_member_mark): - name = l[len(_member_mark):].strip(" \t") - if name in members: - self._documented.add(name) - self._mentioned.add(name) - self._write_member_markdown_to_file(f, prefix, *members[name]) - del members[name] - elif name in imports: - self._write_module_markdown_to_file(f, imports[name]) - else: - raise ValueError("%s: unknown member `%s`, markdown=`%s`." % ( - self._title, name, l)) - else: - print(l, file=f) - - def _write_class_markdown_to_file(self, f, name, cls): - """Write the class doc to `f`. 
- - Args: - f: File to write to. - name: name to use. - cls: class object. - """ - # Build the list of class methods to document. - methods = dict(self.get_class_members(name, cls)) - # Used later to check if any methods were called out in the class - # docstring. - num_methods = len(methods) - try: - self._write_docstring_markdown_to_file(f, "####", inspect.getdoc(cls), - methods, {}) - except ValueError as e: - raise ValueError(str(e) + " in class `%s`" % cls.__name__) - - # If some methods were not described, describe them now if they are - # defined by the class itself (not inherited). If NO methods were - # described, describe all methods. - # - # TODO(touts): when all methods have been categorized make it an error - # if some methods are not categorized. - any_method_called_out = (len(methods) != num_methods) - if any_method_called_out: - other_methods = {n: m for n, m in methods.items() if n in cls.__dict__} - if other_methods: - print("\n#### Other Methods", file=f) - else: - other_methods = methods - for name in sorted(other_methods): - self._write_member_markdown_to_file(f, "####", *other_methods[name]) - - def _write_module_markdown_to_file(self, f, module): - imports = dict(self.get_imported_modules(module)) - self._write_docstring_markdown_to_file(f, "###", inspect.getdoc(module), - self._members, imports) - - def write_markdown_to_file(self, f): - """Prints this library to file `f`. - - Args: - f: File to write to. - - Returns: - Dictionary of documented members. - """ - print("", file=f) - print("", file=f) - # TODO(touts): Do not insert these. Let the doc writer put them in - # the module docstring explicitly. - print("#", self._title, file=f) - if self._prefix: - print(self._prefix, file=f) - print("[TOC]", file=f) - print("", file=f) - if self._module is not None: - self._write_module_markdown_to_file(f, self._module) - - def write_other_members(self, f, catch_all=False): - """Writes the leftover members to `f`. - - Args: - f: File to write to. - catch_all: If true, document all missing symbols from any module. - Otherwise, document missing symbols from just this module. - """ - if catch_all: - names = self._members.items() - else: - names = inspect.getmembers(self._module) - all_names = getattr(self._module, "__all__", None) - if all_names is not None: - names = [(n, m) for n, m in names if n in all_names] - leftovers = [] - for name, _ in names: - if name in self._members and name not in self._documented: - leftovers.append(name) - if leftovers: - print("%s: undocumented members: %d" % (self._title, len(leftovers))) - print("\n## Other Functions and Classes", file=f) - for name in sorted(leftovers): - print(" %s" % name) - self._documented.add(name) - self._mentioned.add(name) - self._write_member_markdown_to_file(f, "###", *self._members[name]) - - def assert_no_leftovers(self): - """Generate an error if there are leftover members.""" - leftovers = [] - for name in self._members: - if name in self._members and name not in self._documented: - leftovers.append(name) - if leftovers: - raise RuntimeError("%s: undocumented members: %s" % - (self._title, ", ".join(leftovers))) - - -def write_libraries(output_dir, libraries): - """Write a list of libraries to disk. - - Args: - output_dir: Output directory. - libraries: List of (filename, library) pairs. - """ - files = [open(os.path.join(output_dir, k), "w") for k, _ in libraries] - - # Set the directory in which to save individual class and function md files, - # creating it if it doesn't exist. 
Create subdirectories to avoid hitting - # the limit for number of files in a directory. - indiv_dir = os.path.join(output_dir, _indiv_dir) - if not os.path.exists(indiv_dir): - os.makedirs(indiv_dir) - - for i in range(0, _num_subdirs): - subdir = os.path.join(indiv_dir, _subdir_prefix + str(i)) - if not os.path.exists(subdir): - os.makedirs(subdir) - - # Document mentioned symbols for all libraries - for f, (_, v) in zip(files, libraries): - v.set_functions_and_classes_dir(indiv_dir) - v.write_markdown_to_file(f) - # Document symbols that no library mentioned. We do this after writing - # out all libraries so that earlier libraries know what later libraries - # documented. - for f, (_, v) in zip(files, libraries): - v.write_other_members(f) - f.close() diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 3c663c3a9b2..8b156db6dc4 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -299,6 +299,7 @@ class _FuncGraph(ops.Graph): shape=None, dtype=None, initializer=None, + reuse=None, trainable=True, collections=None, # pylint: disable=redefined-outer-name use_resource=None, @@ -319,6 +320,7 @@ class _FuncGraph(ops.Graph): shape=shape, dtype=dtype, initializer=initializer, + reuse=reuse, trainable=trainable, collections=collections, use_resource=use_resource) @@ -886,6 +888,11 @@ class Defun(object): default graph. Because the addition of the function into the graph is deferred, the decorator can be used anywhere in the program. + Any variables created inside the function are hoisted into the outer graph. + Note that the variables are created in the variable scope that was active + during the first call to the function. Subsequent function calls will refer to + the same set of variables. + Definitions of functions are frozen in a graph as soon as the graph is used to create a session. Therefore, nodes using the function must be created in the graph before the corresponding session is created.
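The hoisting behavior described in the `Defun` docstring above can be sketched briefly. The following is an illustrative example only, not part of the patch; the names are hypothetical, and it mirrors the `function_test.py` cases that follow, using the same internal-module import style:

```python
# A minimal sketch of variable hoisting in Defun: `w` is created in the
# outer graph during the first call, and later calls reuse it.
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


@function.Defun(dtypes.float32)
def Scale(x):
  # Hoisted into the outer graph on the first call to Scale.
  w = variable_scope.get_variable(
      "w", shape=[], dtype=dtypes.float32,
      initializer=init_ops.ones_initializer())
  return x * w

inp = array_ops.placeholder(dtypes.float32, shape=[2])
out1 = Scale(inp)
out2 = Scale(inp)

# Both calls share a single hoisted variable in the outer graph.
assert len(variables.global_variables()) == 1
```

Two calls to the defined function produce exactly one global variable, which is exactly what the new `testFunctionCallInDifferentVariableScopes` test below asserts.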
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 96bf7bde29f..51a9215ac4e 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -722,6 +722,58 @@ class FunctionTest(test.TestCase): y = Bar(array_ops.zeros([1, 2, 3])) self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3]) + def testVariableReuse(self): + def LinearWithReuse(input_tensor, reuse=None): + size = input_tensor.shape.dims[1] + with variable_scope.variable_scope("linear", reuse=reuse): + w = variable_scope.get_variable("w", shape=[size, size], + dtype=input_tensor.dtype) + return math_ops.matmul(input_tensor, w) + + @function.Defun(dtypes.float32) + def Foo(inputs): + inputs = array_ops.reshape(inputs, [32, 100]) + hidden = LinearWithReuse(inputs) + return LinearWithReuse(hidden, reuse=True) + + input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32) + output_op = Foo(input_op) + + global_vars = variables.global_variables() + self.assertEqual(len(global_vars), 1) + self.assertEqual(global_vars[0].name, "linear/w:0") + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + output_val = sess.run(output_op, + feed_dict={input_op: np.random.rand(32, 100)}) + self.assertEqual(output_val.shape, (32, 100)) + + def testFunctionCallInDifferentVariableScopes(self): + @function.Defun(dtypes.float32) + def Foo(inputs): + var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32, + initializer=init_ops.ones_initializer()) + return inputs + var + + input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32) + with variable_scope.variable_scope("vs1"): + out1_op = Foo(input_op) + + with variable_scope.variable_scope("vs2"): + out2_op = Foo(input_op) + + global_vars = variables.global_variables() + self.assertEqual(len(global_vars), 1) + self.assertEqual(global_vars[0].name, "vs1/var:0") + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + out1, out2 = sess.run([out1_op, out2_op], + feed_dict={input_op: np.linspace(1, 10, 10)}) + self.assertAllEqual(out1, np.linspace(2, 11, 10)) + self.assertAllEqual(out2, np.linspace(2, 11, 10)) + class FunctionsFromProtos(test.TestCase): diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py deleted file mode 100644 index 65379dda209..00000000000 --- a/tensorflow/python/framework/gen_docs_combined.py +++ /dev/null @@ -1,332 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Updates generated docs from Python doc comments.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import collections -import os.path -import sys - -import tensorflow as tf - -from tensorflow.contrib import ffmpeg -from tensorflow.python import debug as tf_debug -from tensorflow.python.client import client_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import docs -from tensorflow.python.framework import framework_lib - -FLAGS = None - - -PREFIX_TEXT = """ -Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. -""" - - -def module_names(): - return [ - "tf", - "tf.errors", - "tf.image", - "tf.nn", - "tf.train", - "tf.python_io", - "tf.saved_model", - "tf.summary", - "tf.test", - "tf.contrib.bayesflow.entropy", - "tf.contrib.bayesflow.monte_carlo", - "tf.contrib.bayesflow.stochastic_graph", - "tf.contrib.bayesflow.stochastic_tensor", - "tf.contrib.bayesflow.variational_inference", - "tf.contrib.copy_graph", - "tf.contrib.crf", - "tf.contrib.distributions", - "tf.contrib.distributions.bijector", - "tf.contrib.ffmpeg", - "tf.contrib.framework", - "tf.contrib.graph_editor", - "tf.contrib.integrate", - "tf.contrib.layers", - "tf.contrib.learn", - "tf.contrib.learn.monitors", - "tf.contrib.legacy_seq2seq", - "tf.contrib.linalg", - "tf.contrib.losses", - "tf.contrib.metrics", - "tf.contrib.opt", - "tf.contrib.rnn", - "tf.contrib.solvers", - "tf.contrib.training", - "tf.contrib.util", - "tf_debug", - ] - - -def find_module(base_module, name): - if name == "tf": - return base_module - # Special case for ffmpeg is needed since it's not linked in by default due - # to size concerns. - elif name == "tf.contrib.ffmpeg": - return ffmpeg - elif name == "tf_debug": - return tf_debug - elif name.startswith("tf."): - subname = name[3:] - subnames = subname.split(".") - parent_module = base_module - for s in subnames: - if not hasattr(parent_module, s): - raise ValueError( - "Module not found: {}. Submodule {} not found in parent module {}." - " Possible candidates are {}".format( - name, s, parent_module.__name__, dir(parent_module))) - parent_module = getattr(parent_module, s) - return parent_module - else: - raise ValueError( - "Invalid module name: {}. Module names must start with 'tf.'".format( - name)) - - -def get_module_to_name(names): - return collections.OrderedDict([(find_module(tf, x), x) for x in names]) - - -def all_libraries(module_to_name, members, documented): - """Make a list of the individual files that we want to create. - - Args: - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - documented: Set of documented names to update. - - Returns: - List of (filename, docs.Library) pairs. - """ - def library(name, title, module=None, **args): - if module is None: - module = sys.modules["tensorflow.python.ops." + name] - return (name + ".md", docs.Library(title=title, - module_to_name=module_to_name, - members=members, - documented=documented, - module=module, - **args)) - return collections.OrderedDict([ - # Splits of module 'tf'. 
- library("framework", "Building Graphs", framework_lib), - library("check_ops", "Asserts and boolean checks."), - library("constant_op", "Constants, Sequences, and Random Values", - constant_op, prefix=PREFIX_TEXT), - library("state_ops", - "Variables", - exclude_symbols=["create_partitioned_variables"], - prefix=PREFIX_TEXT), - library("array_ops", - "Tensor Transformations", - exclude_symbols=["list_diff"], - prefix=PREFIX_TEXT), - library("math_ops", - "Math", - exclude_symbols=["sparse_matmul", "arg_min", "arg_max", - "lin_space", "sparse_segment_mean_grad"], - prefix=PREFIX_TEXT), - library("string_ops", "Strings", - prefix=PREFIX_TEXT), - library("histogram_ops", "Histograms"), - library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT), - library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT), - library("tensor_array_ops", "TensorArray Operations", prefix=PREFIX_TEXT), - library("session_ops", "Tensor Handle Operations", prefix=PREFIX_TEXT), - library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"], - prefix=PREFIX_TEXT), - library("sparse_ops", - "Sparse Tensors", - exclude_symbols=["serialize_sparse", "serialize_many_sparse", - "deserialize_many_sparse"], - prefix=PREFIX_TEXT), - library("io_ops", - "Inputs and Readers", - exclude_symbols=["LookupTableBase", "HashTable", - "initialize_all_tables", - "tables_initializer", - "parse_single_sequence_example", - "string_to_hash_bucket"], - prefix=PREFIX_TEXT), - library("python_io", "Data IO (Python functions)", tf.python_io), - library("nn", - "Neural Network", - tf.nn, - exclude_symbols=["conv2d_backprop_input", - "conv2d_backprop_filter", "avg_pool_grad", - "max_pool_grad", "max_pool_grad_with_argmax", - "batch_norm_with_global_normalization_grad", - "lrn_grad", "relu6_grad", "softplus_grad", - "softsign_grad", "xw_plus_b", "relu_layer", - "lrn", "batch_norm_with_global_normalization", - "batch_norm_with_global_normalization_grad", - "all_candidate_sampler", "seq2seq"], - prefix=PREFIX_TEXT), - library("client", "Running Graphs", client_lib), - library("train", - "Training", - tf.train, - exclude_symbols=["Feature", "Features", "BytesList", "FloatList", - "Int64List", "Example", "InferenceExample", - "FeatureList", "FeatureLists", "RankingExample", - "SequenceExample"]), - library("script_ops", - "Wraps python functions", - prefix=PREFIX_TEXT), - library("summary", "Summary Operations", tf.summary), - library("test", "Testing", tf.test), - library("contrib.bayesflow.entropy", - "BayesFlow Entropy (contrib)", - tf.contrib.bayesflow.entropy), - library("contrib.bayesflow.monte_carlo", - "BayesFlow Monte Carlo (contrib)", - tf.contrib.bayesflow.monte_carlo), - library("contrib.bayesflow.stochastic_graph", - "BayesFlow Stochastic Graph (contrib)", - tf.contrib.bayesflow.stochastic_graph), - library("contrib.bayesflow.stochastic_tensor", - "BayesFlow Stochastic Tensors (contrib)", - tf.contrib.bayesflow.stochastic_tensor), - library("contrib.bayesflow.variational_inference", - "BayesFlow Variational Inference (contrib)", - tf.contrib.bayesflow.variational_inference), - library("contrib.crf", "CRF (contrib)", tf.contrib.crf), - library("contrib.distributions", "Statistical Distributions (contrib)", - tf.contrib.distributions), - library("contrib.distributions.bijector", - "Random variable transformations (contrib)", - tf.contrib.distributions.bijector), - library("contrib.ffmpeg", "FFmpeg (contrib)", ffmpeg), - library("contrib.framework", "Framework (contrib)", tf.contrib.framework), - 
library("contrib.graph_editor", "Graph Editor (contrib)", - tf.contrib.graph_editor), - library("contrib.integrate", "Integrate (contrib)", tf.contrib.integrate), - library("contrib.layers", "Layers (contrib)", tf.contrib.layers), - library("contrib.learn", "Learn (contrib)", tf.contrib.learn), - library("contrib.learn.monitors", "Monitors (contrib)", - tf.contrib.learn.monitors), - library("contrib.legacy_seq2seq", "Sequence to Sequence (contrib)", - tf.contrib.legacy_seq2seq), - library("contrib.linalg", "Linear Algebra (contrib)", - tf.contrib.linalg), - library("contrib.losses", "Losses (contrib)", tf.contrib.losses), - library("contrib.opt", "Optimization (contrib)", tf.contrib.opt), - library("contrib.rnn", "RNN and Cells (contrib)", tf.contrib.rnn), - library("contrib.metrics", "Metrics (contrib)", tf.contrib.metrics), - library("contrib.training", "Training (contrib)", tf.contrib.training), - library("contrib.util", "Utilities (contrib)", tf.contrib.util), - library("contrib.copy_graph", "Copying Graph Elements (contrib)", - tf.contrib.copy_graph), - library("tf_debug", "TensorFlow Debugger", tf_debug), - ]) - -_hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", - "HistogramProto", "ConfigProto", "NodeDef", "GraphDef", - "GPUOptions", "GraphOptions", "RunOptions", "RunMetadata", - "SessionInterface", "BaseSession", "NameAttrList", - "AttrValue", "OptimizerOptions", - "CollectionDef", "MetaGraphDef", "QueueRunnerDef", - "SaverDef", "VariableDef", "TestCase", "GrpcServer", - "ClusterDef", "JobDef", "ServerDef", "TensorInfo"] - -# TODO(skleinfeld, deannarubin) Address shortname -# conflict between tf.contrib.learn.NanLossDuringTrainingError and -# tf.contrib.learn.monitors.NanLossDuringTrainingError, arising due -# to imports in learn/python/learn/__init__.py -# TODO(wicke): Remove contrib.layers.relu* after shortnames are -# disabled. These conflict with tf.nn.relu* -EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError", - "tf.contrib.layers.dropout", - "tf.contrib.layers.bias_add", - "tf.contrib.layers.conv2d", - "tf.contrib.layers.conv2d_transpose", - "tf.contrib.layers.separable_conv2d", - "tf.contrib.layers.softmax", - "tf.contrib.layers.relu", "tf.contrib.layers.relu6", - "tf.contrib.framework.assert_global_step", - "tf.contrib.framework.get_global_step", - "tf.contrib.learn.NanLossDuringTrainingError", - "tf.contrib.layers.stack", - "tf.contrib.layers.ProblemType", - "tf.confusion_matrix"]) - - -def main(unused_argv): - if not FLAGS.out_dir: - tf.logging.error("out_dir not specified") - return -1 - - # Document libraries - documented = set() - module_to_name = get_module_to_name(module_names()) - members = docs.collect_members(module_to_name, exclude=EXCLUDE) - libraries = all_libraries(module_to_name, members, documented).items() - - # Define catch_all library before calling write_libraries to avoid complaining - # about generically hidden symbols. - catch_all = docs.Library(title="Catch All", module=None, - exclude_symbols=_hidden_symbols, - module_to_name=module_to_name, members=members, - documented=documented) - - # Write docs to files - docs.write_libraries(FLAGS.out_dir, libraries) - - # Make it easy to search for hidden symbols - if FLAGS.print_hidden_regex: - hidden = set(_hidden_symbols) - for _, lib in libraries: - hidden.update(lib.exclude_symbols) - print(r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden))) - - # Verify that all symbols are mentioned in some library doc. 
- catch_all.assert_no_leftovers() - - # Generate index - with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f: - docs.Index(module_to_name, members, libraries, - "../../api_docs/python/").write_markdown_to_file(f) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--out_dir", - type=str, - default=None, - help="Directory to which docs should be written.") - parser.add_argument( - "--print_hidden_regex", - type="bool", - nargs="?", - const=True, - default=False, - help="Dump a regular expression matching any hidden symbol") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 333c6b8f2ba..f21402652fa 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1578,9 +1578,9 @@ cuda_py_test( ) cuda_py_test( - name = "scalar_strict_test", + name = "scalar_test", size = "small", - srcs = ["scalar_strict_test.py"], + srcs = ["scalar_test.py"], additional_deps = [ "//third_party/py/numpy", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py index afdb436dc68..08b03f85180 100644 --- a/tensorflow/python/kernel_tests/betainc_op_test.py +++ b/tensorflow/python/kernel_tests/betainc_op_test.py @@ -25,76 +25,78 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging class BetaincTest(test.TestCase): - use_gpu = False - def _testBetaInc(self, dtype): + def _testBetaInc(self, a_s, b_s, x_s, dtype): try: from scipy import special # pylint: disable=g-import-not-at-top np_dt = dtype.as_numpy_dtype # Test random values - a_s = np.abs(np.random.randn(10, 10) * 30).astype(np_dt) # in (0, infty) - b_s = np.abs(np.random.randn(10, 10) * 30).astype(np_dt) # in (0, infty) - x_s = np.random.rand(10, 10).astype(np_dt) # in (0, 1) - with self.test_session(use_gpu=self.use_gpu): - tf_a_s = constant_op.constant(a_s, dtype=dtype) - tf_b_s = constant_op.constant(b_s, dtype=dtype) - tf_x_s = constant_op.constant(x_s, dtype=dtype) - tf_out = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s).eval() + a_s = a_s.astype(np_dt) # in (0, infty) + b_s = b_s.astype(np_dt) # in (0, infty) + x_s = x_s.astype(np_dt) # in (0, 1) + tf_a_s = constant_op.constant(a_s, dtype=dtype) + tf_b_s = constant_op.constant(b_s, dtype=dtype) + tf_x_s = constant_op.constant(x_s, dtype=dtype) + tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s) + with self.test_session(): + tf_out = tf_out_t.eval() scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt) # the scipy version of betainc uses a double-only implementation. 
# TODO(ebrevdo): identify reasons for (sometime) precision loss # with doubles tol = 1e-4 if dtype == dtypes.float32 else 5e-5 - self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=tol) + self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=0) # Test out-of-range values (most should return nan output) combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3)) a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt) - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval() scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt) self.assertAllCloseAccordingToType(scipy_comb, tf_comb) # Test broadcasting between scalars and other shapes - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): self.assertAllCloseAccordingToType( special.betainc(0.1, b_s, x_s).astype(np_dt), math_ops.betainc(0.1, b_s, x_s).eval(), rtol=tol, - atol=tol) + atol=0) self.assertAllCloseAccordingToType( special.betainc(a_s, 0.1, x_s).astype(np_dt), math_ops.betainc(a_s, 0.1, x_s).eval(), rtol=tol, - atol=tol) + atol=0) self.assertAllCloseAccordingToType( special.betainc(a_s, b_s, 0.1).astype(np_dt), math_ops.betainc(a_s, b_s, 0.1).eval(), rtol=tol, - atol=tol) + atol=0) self.assertAllCloseAccordingToType( special.betainc(0.1, b_s, 0.1).astype(np_dt), math_ops.betainc(0.1, b_s, 0.1).eval(), rtol=tol, - atol=tol) + atol=0) self.assertAllCloseAccordingToType( special.betainc(0.1, 0.1, 0.1).astype(np_dt), math_ops.betainc(0.1, 0.1, 0.1).eval(), rtol=tol, - atol=tol) + atol=0) with self.assertRaisesRegexp(ValueError, "must be equal"): math_ops.betainc(0.5, [0.5], [[0.5]]) - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): with self.assertRaisesOpError("Shapes of .* are inconsistent"): a_p = array_ops.placeholder(dtype) b_p = array_ops.placeholder(dtype) @@ -108,14 +110,79 @@ class BetaincTest(test.TestCase): tf_logging.warn("Cannot test special functions: %s" % str(e)) def testBetaIncFloat(self): - self._testBetaInc(dtypes.float32) + a_s = np.abs(np.random.randn(10, 10) * 30) # in (0, infty) + b_s = np.abs(np.random.randn(10, 10) * 30) # in (0, infty) + x_s = np.random.rand(10, 10) # in (0, 1) + self._testBetaInc(a_s, b_s, x_s, dtypes.float32) def testBetaIncDouble(self): - self._testBetaInc(dtypes.float64) + a_s = np.abs(np.random.randn(10, 10) * 30) # in (0, infty) + b_s = np.abs(np.random.randn(10, 10) * 30) # in (0, infty) + x_s = np.random.rand(10, 10) # in (0, 1) + self._testBetaInc(a_s, b_s, x_s, dtypes.float64) + def testBetaIncDoubleVeryLargeValues(self): + a_s = np.abs(np.random.randn(10, 10) * 1e15) # in (0, infty) + b_s = np.abs(np.random.randn(10, 10) * 1e15) # in (0, infty) + x_s = np.random.rand(10, 10) # in (0, 1) + self._testBetaInc(a_s, b_s, x_s, dtypes.float64) -class BetaincTestGPU(BetaincTest): - use_gpu = True + def testBetaIncDoubleVerySmallValues(self): + a_s = np.abs(np.random.randn(10, 10) * 1e-16) # in (0, infty) + b_s = np.abs(np.random.randn(10, 10) * 1e-16) # in (0, infty) + x_s = np.random.rand(10, 10) # in (0, 1) + self._testBetaInc(a_s, b_s, x_s, dtypes.float64) + + def testBetaIncFloatVerySmallValues(self): + a_s = np.abs(np.random.randn(10, 10) * 1e-8) # in (0, infty) + b_s = np.abs(np.random.randn(10, 10) * 1e-8) # in (0, infty) + x_s = np.random.rand(10, 10) # in (0, 1) + self._testBetaInc(a_s, b_s, x_s, dtypes.float32) + + def testBetaIncFpropAndBpropAreNeverNAN(self): + with 
self.test_session() as sess: + space = np.logspace(-8, 5).tolist() + space_x = np.linspace(1e-16, 1 - 1e-16).tolist() + ga_s, gb_s, gx_s = zip(*list(itertools.product(space, space, space_x))) + # Test grads are never nan + ga_s_t = constant_op.constant(ga_s, dtype=dtypes.float32) + gb_s_t = constant_op.constant(gb_s, dtype=dtypes.float32) + gx_s_t = constant_op.constant(gx_s, dtype=dtypes.float32) + tf_gout_t = math_ops.betainc(ga_s_t, gb_s_t, gx_s_t) + tf_gout, grads_x = sess.run( + [tf_gout_t, + gradients_impl.gradients(tf_gout_t, [ga_s_t, gb_s_t, gx_s_t])[2]]) + + # Equivalent to `assertAllFalse` (if it existed). + self.assertAllEqual(np.zeros_like(grads_x).astype(np.bool), + np.isnan(tf_gout)) + self.assertAllEqual(np.zeros_like(grads_x).astype(np.bool), + np.isnan(grads_x)) + + def testBetaIncGrads(self): + err_tolerance = 1e-3 + with self.test_session(): + # Test gradient + ga_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty) + gb_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty) + gx_s = np.random.rand(2, 2) # in (0, 1) + tf_ga_s = constant_op.constant(ga_s, dtype=dtypes.float64) + tf_gb_s = constant_op.constant(gb_s, dtype=dtypes.float64) + tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64) + tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s) + err = gradient_checker.compute_gradient_error( + [tf_gx_s], [gx_s.shape], tf_gout_t, gx_s.shape) + print("betainc gradient err = %g " % err) + self.assertLess(err, err_tolerance) + + # Test broadcast gradient + gx_s = np.random.rand() # in (0, 1) + tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64) + tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s) + err = gradient_checker.compute_gradient_error( + [tf_gx_s], [()], tf_gout_t, ga_s.shape) + print("betainc gradient err = %g " % err) + self.assertLess(err, err_tolerance) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/scalar_strict_test.py b/tensorflow/python/kernel_tests/scalar_test.py similarity index 95% rename from tensorflow/python/kernel_tests/scalar_strict_test.py rename to tensorflow/python/kernel_tests/scalar_test.py index e208217637c..b34426cc215 100644 --- a/tensorflow/python/kernel_tests/scalar_strict_test.py +++ b/tensorflow/python/kernel_tests/scalar_test.py @@ -27,20 +27,15 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import -from tensorflow.python.platform import control_imports from tensorflow.python.platform import test -class ScalarStrictTest(test.TestCase): +class ScalarTest(test.TestCase): def check(self, op, args, error, correct=None): # Within Google, the switch to scalar strict occurred at version 6. - if control_imports.USE_OSS: - lenient = [] - strict = [5, 6] - else: - lenient = [5] - strict = [6] + lenient = [] + strict = [5, 6] # Use placeholders to bypass shape inference, since only the C++ # GraphDef level is ever scalar lenient. 
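A side note on the placeholder trick mentioned in the comment above: a placeholder built without a static shape keeps Python-side shape inference out of the loop, so scalar/shape strictness is enforced only at the C++ GraphDef level once the graph runs. A minimal, hypothetical sketch (not part of the patch), using plain TF 1.x APIs:

```python
# A placeholder with no static shape defeats Python shape inference;
# the fed value's shape is first seen by the C++ runtime at run time.
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

x = array_ops.placeholder(dtypes.float32)  # unknown rank and shape
y = math_ops.reduce_sum(x)  # construction cannot reject any fed shape here

with session.Session() as sess:
  print(sess.run(y, feed_dict={x: [1.0, 2.0, 3.0]}))  # prints 6.0
```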
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index be2d6a566ab..54e8098e4e6 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -306,7 +306,7 @@ class TemplateTest(test.TestCase): self.assertEqual(custom_getter_count[0], 2) # Test that custom getter is called when the variable scope is created - # during construction + # during construction custom_getter_count[0] = 0 tmpl2 = template.make_template( "s2", @@ -319,6 +319,28 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) + def test_fails_gracefully(self): + for create_scope_now in [True, False]: + def module_function_with_one_arg(inputs): + w = variable_scope.get_variable( + "w", shape=[1], initializer=init_ops.zeros_initializer()) + return inputs * w + + templatized_function = template.make_template( + "f1", module_function_with_one_arg, + create_scope_now_=create_scope_now) + data = array_ops.zeros(1) + try: + # Try to connect with a kwarg which is unsupported. + templatized_function(data, is_training=True) + except TypeError: + pass + + # The failed __call__ hasn't modified the inner state. + self.assertFalse(templatized_function._variables_created) + templatized_function(data) + self.assertTrue(templatized_function._variables_created) + def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index e52ae95281a..9aadf638f2d 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class TensorArrayTest(test.TestCase): def testTensorArrayWriteRead(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -61,7 +61,7 @@ class TensorArrayTest(test.TestCase): def _testTensorArrayWritePack(self, tf_dtype): dtype = tf_dtype.as_numpy_dtype() - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -92,9 +92,23 @@ class TensorArrayTest(test.TestCase): def testTensorArrayWritePack(self): self._testTensorArrayWritePackMaybeLegacy() + def testEmptyTensorArrayPack(self): + with self.test_session(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + empty_element = np.zeros((0, 1), dtype=np.float32) + w0 = ta.write(0, empty_element) + w1 = w0.write(1, empty_element) + w2 = w1.write(2, empty_element) + + c0 = w2.stack() + + self.assertAllEqual([3, 0, 1], c0.eval().shape) + def _testTensorArrayWriteConcat(self, tf_dtype): dtype = tf_dtype.as_numpy_dtype() - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False) @@ -137,7 +151,7 @@ class TensorArrayTest(test.TestCase): def _testTensorArrayUnpackRead(self, tf_dtype): dtype = tf_dtype.as_numpy_dtype() - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -202,7 +216,7 @@ class 
TensorArrayTest(test.TestCase): def _testTensorArraySplitRead(self, tf_dtype): dtype = tf_dtype.as_numpy_dtype() - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False) @@ -259,7 +273,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArraySplitRead(dtypes.string) def testTensorGradArrayWriteRead(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -292,7 +306,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(-2.0, g_d2) def testTensorGradArrayDynamicWriteRead(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -333,7 +347,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(3, g_vs) def testTensorGradAccessTwiceReceiveSameObject(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) g_ta_0 = ta.grad("grad") @@ -349,7 +363,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[4.0, 5.0]], d_r1_0) def testTensorArrayWriteWrongIndexOrDataTypeFails(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -371,7 +385,7 @@ class TensorArrayTest(test.TestCase): ta.write(3, 3.0).flow.eval() def testTensorArrayReadWrongIndexOrDataTypeFails(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -402,7 +416,7 @@ class TensorArrayTest(test.TestCase): ta.read(3).eval() def testTensorArrayWriteMultipleFails(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -412,7 +426,7 @@ class TensorArrayTest(test.TestCase): ta.write(2, 3.0).write(2, 3.0).flow.eval() def testTensorArrayConcatIncompatibleShapesFails(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -444,7 +458,7 @@ class TensorArrayTest(test.TestCase): w3.concat().eval() def testTensorArraySplitIncompatibleShapesFails(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -478,7 +492,7 @@ class TensorArrayTest(test.TestCase): ta.split([1.0], [1]).flow.eval() def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) ta_grad = ta.grad("grad") @@ -517,7 +531,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): - with self.test_session(use_gpu=True): + with self.test_session(): h1 = tensor_array_ops.TensorArray( size=1, dtype=dtypes.float32, tensor_array_name="foo") w1 = h1.write(0, 4.0) @@ -532,7 +546,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(9.0, r.eval()) def 
_testTensorArrayGradientWriteReadType(self, dtype): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.as_dtype(dtype), tensor_array_name="foo", @@ -584,7 +598,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): - with self.test_session(use_gpu=True) as sess: + with self.test_session() as sess: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -619,7 +633,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayGradientWritePackConcatAndRead() def testTensorArrayReadTwice(self): - with self.test_session(use_gpu=True): + with self.test_session(): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) ta_readonce = tensor_array_ops.TensorArray( @@ -648,7 +662,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) def _testTensorArrayGradientUnpackRead(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -676,7 +690,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayGradientUnpackRead() def testTensorArrayGradientSplitConcat(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=2) @@ -698,7 +712,7 @@ class TensorArrayTest(test.TestCase): grad_vals[0]) def _testTensorArrayGradientDynamicUnpackRead(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -723,21 +737,21 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayGradientDynamicUnpackRead() def testCloseTensorArray(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c1 = ta.close() session.run(c1) def testSizeTensorArray(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() self.assertAllEqual(3, s.eval()) def testWriteCloseTensorArray(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -749,7 +763,7 @@ class TensorArrayTest(test.TestCase): def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): np_dtype = dtype.as_numpy_dtype - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) @@ -864,7 +878,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(31.0, grad.eval()) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: a = array_ops.identity( np.arange( 3 * 5, dtype=np.float32).reshape(3, 5) + 1) @@ -1062,7 +1076,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def _testGradientWhenNotAllComponentsRead(self): - 
with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) x = constant_op.constant([2.0, 3.0]) w = ta.unstack(x) @@ -1076,7 +1090,7 @@ class TensorArrayTest(test.TestCase): self._testGradientWhenNotAllComponentsRead() def _testTensorArrayUnpackDynamic(self): - with self.test_session(use_gpu=True) as sess: + with self.test_session() as sess: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=3, dynamic_size=True) x = constant_op.constant([1.0, 2.0, 3.0]) @@ -1091,7 +1105,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayUnpackDynamic() def testTensorArraySplitDynamic(self): - with self.test_session(use_gpu=True) as sess: + with self.test_session() as sess: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=3, dynamic_size=True) x = constant_op.constant([1.0, 2.0, 3.0]) @@ -1103,7 +1117,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(np.array([1.0, 1.0, 1.0]), sess.run(grad)[0]) def _testTensorArrayEvalEmpty(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=False) with self.assertRaisesOpError( @@ -1116,7 +1130,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayEvalEmpty() def _testTensorArrayEvalEmptyWithDefault(self): - with self.test_session(use_gpu=True): + with self.test_session(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True) self.assertEqual(0, ta.size().eval()) @@ -1132,7 +1146,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayEvalEmptyWithDefault() def testTensorArrayScatterReadAndGradients(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -1158,7 +1172,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) def testTensorArrayWriteGatherAndGradients(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -1249,7 +1263,7 @@ class TensorArrayTest(test.TestCase): self.assertTrue("gpu:0" in ta.handle.device.lower()) def testTensorArrayLazyDeviceSettingDoesNotConfuseInitialAccess(self): - with self.test_session(use_gpu=True) as session: + with self.test_session() as session: ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) self.assertEqual(ta.handle.device, "") diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py index 939a86c0b50..7b112a6a17b 100644 --- a/tensorflow/python/kernel_tests/transpose_op_test.py +++ b/tensorflow/python/kernel_tests/transpose_op_test.py @@ -358,9 +358,6 @@ class TransposeTest(test.TestCase): with self.assertRaises(ValueError): array_ops.transpose( np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]]) - self._testError( - np.arange(0., 2**11).reshape([2] * 11), np.arange(11), - "not implemented") with self.assertRaises(ValueError): array_ops.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3]) self._testError( diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 040a4513caa..f6062aa03d9 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ 
-265,7 +265,7 @@ class NumpyTensorBuffer : public TensorBuffer { Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { PyArrayObject* input = reinterpret_cast<PyArrayObject*>(obj); - DataType dtype; + DataType dtype = DT_INVALID; TensorShape shape; for (int i = 0; i < PyArray_NDIM(input); ++i) { shape.AddDim(PyArray_SHAPE(input)[i]); diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 8aa8de530c2..409e3c5111e 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -429,7 +429,7 @@ def _DigammaGrad(op, grad): @ops.RegisterGradient("Igamma") def _IgammaGrad(op, grad): - """Returns gradient of igamma(a, x) with respect to a and x.""" + """Returns gradient of igamma(a, x) with respect to x.""" # TODO(ebrevdo): Perhaps add the derivative w.r.t. a a = op.inputs[0] x = op.inputs[1] @@ -440,14 +440,43 @@ def _IgammaGrad(op, grad): # Perform operations in log space before summing, because Gamma(a) # and Gamma'(a) can grow large. partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) + # TODO(b/36815900): Mark None return values as NotImplemented return (None, array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") def _IgammacGrad(op, grad): - """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" - return [-1 * g if g is not None else None for g in _IgammaGrad(op, grad)] + """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x.""" + _, igamma_grad_x = _IgammaGrad(op, grad) + return None, -igamma_grad_x + + +@ops.RegisterGradient("Betainc") +def _BetaincGrad(op, grad): + """Returns gradient of betainc(a, b, x) with respect to x.""" + # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b + a, b, x = op.inputs + + # Two cases: x is a scalar and a/b are same-shaped tensors, or vice + # versa; so it's sufficient to check against shape(a). + sa = array_ops.shape(a) + sx = array_ops.shape(x) + # pylint: disable=protected-access + _, rx = gen_array_ops._broadcast_gradient_args(sa, sx) + # pylint: enable=protected-access + + # Perform operations in log space before summing, because terms + # can grow large. + log_beta = (gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) + - gen_math_ops.lgamma(a + b)) + partial_x = math_ops.exp( + (b - 1) * math_ops.log(1 - x) + (a - 1) * math_ops.log(x) - log_beta) + + # TODO(b/36815900): Mark None return values as NotImplemented + return (None, # da + None, # db + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Zeta") @@ -465,6 +494,7 @@ def _ZetaGrad(op, grad): x = math_ops.conj(x) q = math_ops.conj(q) partial_q = -x * math_ops.zeta(x + 1, q) + # TODO(b/36815900): Mark None return values as NotImplemented return (None, array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) @@ -484,6 +514,7 @@ def _PolygammaGrad(op, grad): n = math_ops.conj(n) x = math_ops.conj(x) partial_x = math_ops.polygamma(n + 1, x) + # TODO(b/36815900): Mark None return values as NotImplemented return (None, array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 80dd74521be..48be9e2cdae 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -261,20 +261,23 @@ class Template(object): return self._call_func(args, kwargs, check_for_new_variables=True) else: # This is the first visit to __call__, but the scope has already been - # created in the constructor.
Set _variables_created so that subsequent - # calls take the if branch above. - self._variables_created = True + # created in the constructor. Set _variables_created after the inner + # function is successfully called so that subsequent calls take the if + # branch above. with variable_scope.variable_scope(self._variable_scope): - return self._call_func(args, kwargs, check_for_new_variables=False) + result = self._call_func(args, kwargs, check_for_new_variables=False) + self._variables_created = True + return result else: # The scope was not created at construction time, so create it here. # Subsequent calls should reuse variables. - self._variables_created = True with variable_scope.variable_scope( self._unique_name, self._name, custom_getter=self._custom_getter) as vs: self._variable_scope = vs - return self._call_func(args, kwargs, check_for_new_variables=False) + result = self._call_func(args, kwargs, check_for_new_variables=False) + self._variables_created = True + return result @property def variable_scope(self): diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 2f97abdc791..19c5d3c3ea0 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -904,6 +904,7 @@ class VariableScope(object): dtype=None, initializer=None, regularizer=None, + reuse=None, trainable=True, collections=None, caching_device=None, @@ -920,6 +921,8 @@ class VariableScope(object): partitioner = self._partitioner if custom_getter is None: custom_getter = self._custom_getter + if reuse is None: + reuse = self._reuse full_name = self.name + "/" + name if self.name else name # Variable names only depend on variable_scope (full_name here), @@ -942,7 +945,7 @@ class VariableScope(object): return var_store.get_variable( full_name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, reuse=self.reuse, trainable=trainable, + regularizer=regularizer, reuse=reuse, trainable=trainable, collections=collections, caching_device=caching_device, partitioner=partitioner, validate_shape=validate_shape, use_resource=use_resource, custom_getter=custom_getter) diff --git a/tensorflow/python/platform/control_imports.py b/tensorflow/python/platform/control_imports.py deleted file mode 100644 index 61b29ca4e57..00000000000 --- a/tensorflow/python/platform/control_imports.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Switch between Google or open source dependencies.""" -# Switch between Google and OSS dependencies -USE_OSS = True - -# Per-dependency switches determining whether each dependency is ready -# to be replaced by its OSS equivalence. 
-# TODO(danmane,mrry,opensource): Flip these switches, then remove them -OSS_APP = True -OSS_FLAGS = True -OSS_GFILE = True -OSS_GOOGLETEST = True -OSS_LOGGING = True -OSS_PARAMETERIZED = True diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md index 9fd795adb51..8422fb6404b 100644 --- a/tensorflow/python/saved_model/README.md +++ b/tensorflow/python/saved_model/README.md @@ -22,6 +22,8 @@ The following is a summary of the features in SavedModel: and outputs. This is called a `Signature`. * SavedModel uses [SignatureDefs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto) to allow generic support for signatures that may need to be saved with the graphs. + * For commonly used SignatureDefs in the context of TensorFlow Serving, + please see documentation [here](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md). * Support for `Assets`. * For cases where ops depend on external files for initialization, such as vocabularies, SavedModel supports this via `assets`. diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py index 82dee45d267..52bc913b2ad 100644 --- a/tensorflow/python/summary/text_summary.py +++ b/tensorflow/python/summary/text_summary.py @@ -34,12 +34,17 @@ def text_summary(name, tensor, collections=None): """Summarizes textual data. Text data summarized via this plugin will be visible in the Text Dashboard - in TensorBoard. + in TensorBoard. The standard TensorBoard Text Dashboard will render markdown + in the strings, and will automatically organize 1d and 2d tensors into tables. + If a tensor with more than 2 dimensions is provided, a 2d subarray will be + displayed along with a warning message. (Note that this behavior is not + intrinsic to the text summary api, but rather to the default TensorBoard text + plugin.) Args: name: A name for the generated node. Will also serve as a series name in TensorBoard. - tensor: a scalar string-type Tensor to summarize. + tensor: a string-type Tensor to summarize. collections: Optional list of ops.GraphKeys. The collections to add the summary to. Defaults to [_ops.GraphKeys.SUMMARIES] @@ -49,16 +54,12 @@ def text_summary(name, tensor, collections=None): type `string` which contains `Summary` protobufs. Raises: - ValueError: If tensor has the wrong shape or type. + ValueError: If tensor has the wrong type. """ if tensor.dtype != dtypes.string: raise ValueError("Expected tensor %s to have dtype string, got %s" % (tensor.name, tensor.dtype)) - if tensor.shape.ndims != 0: - raise ValueError("Expected tensor %s to be scalar, has shape %s" % - (tensor.name, tensor.shape)) - t_summary = tensor_summary(name, tensor, collections=collections) text_assets = plugin_asset.get_plugin_asset(TextSummaryPluginAsset) text_assets.register_tensor(t_summary.op.name) diff --git a/tensorflow/python/summary/text_summary_test.py b/tensorflow/python/summary/text_summary_test.py index 69739573c10..31009702ca4 100644 --- a/tensorflow/python/summary/text_summary_test.py +++ b/tensorflow/python/summary/text_summary_test.py @@ -40,17 +40,19 @@ class TextPluginTest(test_util.TensorFlowTestCase): num = array_ops.constant(1) text_summary.text_summary("foo", num) - with self.assertRaises(ValueError): - arr = array_ops.constant(["one", "two", "three"]) - text_summary.text_summary("foo", arr) + # The API accepts vectors. 
+ arr = array_ops.constant(["one", "two", "three"]) + summ = text_summary.text_summary("foo", arr) + self.assertEqual(summ.op.type, "TensorSummary") + # the API accepts scalars summ = text_summary.text_summary("foo", array_ops.constant("one")) self.assertEqual(summ.op.type, "TensorSummary") - text_summary.text_summary("bar", array_ops.constant("2"), collections=[]) - summaries = framework_ops.get_collection( - framework_ops.GraphKeys.SUMMARIES) - self.assertEqual(len(summaries), 1) + def testTextSummaryCollections(self): + text_summary.text_summary("bar", array_ops.constant("2"), collections=[]) + summaries = framework_ops.get_collection(framework_ops.GraphKeys.SUMMARIES) + self.assertEqual(len(summaries), 0) if __name__ == "__main__": diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index f13b87dfed6..6fd20ce8013 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -40,6 +40,7 @@ from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs @@ -124,7 +125,7 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): def __init__(self, tensors, every_n_iter=None, every_n_secs=None, formatter=None): - """Initializes a LoggingHook monitor. + """Initializes a `LoggingTensorHook`. Args: tensors: `dict` that maps string-valued tags to tensors/tensor names, @@ -189,10 +190,10 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): class StopAtStepHook(session_run_hook.SessionRunHook): - """Monitor to request stop at a specified step.""" + """Hook that requests stop at a specified step.""" def __init__(self, num_steps=None, last_step=None): - """Create a StopAtStep Hook. + """Initializes a `StopAtStepHook`. This hook requests stop after either a number of steps have been executed or a last step has been reached. Only one of the two options can be @@ -234,51 +235,48 @@ class StopAtStepHook(session_run_hook.SessionRunHook): class CheckpointSaverListener(object): - """An interface for event hooks that depend on a checkpoint. + """Interface for listeners that take action before or after checkpoint save. - CheckpointSaverListeners are similar to SessionRunHooks, and can be useful to - track training, report progress, and more. The distinction is that - CheckpointSaverListeners run only in steps when CheckpointSaverHook is - triggered, and provide callbacks to run before or after the checkpoint is - generated. This is in contrast to SessionRunHooks, which may run in steps - when no checkpoint is written, and which have no guaranteed execution order - in any case. 
CheckpointSaverListeners use the observer pattern and notify at - the following points: - - when a session starts being used + `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is + triggered, and provides callbacks at the following points: + - before using the session - before each call to `Saver.save()` - after each call to `Saver.save()` - - when the session closed + - at the end of the session - Custom CheckpointSaverListeners look like this: - class ExampleCheckpointSaverListerner(CheckpointSaverListener): - def begin(self): - # You can add ops to the graph here. - print('Starting the session.') - self.your_tensor = ... + To use a listener, implement a class and pass the listener to a + `CheckpointSaverHook`, as in this example: - def before_save(self, session, global_step_value): - print('About to write a checkpoint') + ```python + class ExampleCheckpointSaverListener(CheckpointSaverListener): + def begin(self): + # You can add ops to the graph here. + print('Starting the session.') + self.your_tensor = ... - def after_save(self, session, global_step_value): - print('Done writing checkpoint.') + def before_save(self, session, global_step_value): + print('About to write a checkpoint') - def end(self, session, global_step_value): - print('Done with the session.') + def after_save(self, session, global_step_value): + print('Done writing checkpoint.') - A CheckpointSaverListener may simply take some action after every checkpoint. - It is also possible for the listener to use its own schedule to act less - frequently, based on wall clock time or on global_step_value. In this case, - implementors must be careful about what happens at end(). When end is called, - The CheckpointSaverHook will have already triggered after_save() in the same - global_step, but the listener may or may not have actually acted on it. - The listener may want to be sure to act at end() if there is a fresh - checkpoint available, but should not act twice if after_save() already handled - it. In this case, end() should have logic to detect the situation and do the - right thing, similar to what CheckpointSaverHook.end() does using - self._timer.last_triggered_step(). + def end(self, session, global_step_value): + print('Done with the session.') - To use such listeners, in your `model_fn` return a `CheckpointSaverHook` as - part of `training_chief_hooks`. + ... + listener = ExampleCheckpointSaverListener() + saver_hook = tf.train.CheckpointSaverHook( + checkpoint_dir, listeners=[listener]) + with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]): + ... + ``` + + A `CheckpointSaverListener` may simply take some action after every + checkpoint save. It is also possible for the listener to use its own schedule + to act less frequently, e.g. based on global_step_value. In this case, + implementors should implement the `end()` method to handle actions related to + the last checkpoint save. But the listener should not act twice if + `after_save()` already handled this last checkpoint save. """ def begin(self): @@ -305,7 +303,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): checkpoint_basename="model.ckpt", scaffold=None, listeners=None): - """Initialize CheckpointSaverHook monitor. + """Initializes a `CheckpointSaverHook`. Args: checkpoint_dir: `str`, base directory for the checkpoint files. @@ -315,18 +313,18 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
       listeners: List of `CheckpointSaverListener` subclass instances.
-        Used for callbacks that run immediately after the corresponding
-        CheckpointSaverHook callbacks, only in steps where the
-        CheckpointSaverHook was triggered.
+        Used for callbacks that run immediately before or after this hook saves
+        the checkpoint.
 
     Raises:
       ValueError: One of `save_steps` or `save_secs` should be set.
-      ValueError: Exactly one of saver or scaffold should be set.
+      ValueError: At most one of `saver` or `scaffold` should be set.
     """
     logging.info("Create CheckpointSaverHook.")
-    if ((saver is None and scaffold is None) or
-        (saver is not None and scaffold is not None)):
-      raise ValueError("Exactly one of saver or scaffold must be provided.")
+    if saver is not None and scaffold is not None:
+      raise ValueError("You cannot provide both saver and scaffold.")
+    if saver is None and scaffold is None:
+      saver = saver_lib._get_saver_or_default()  # pylint: disable=protected-access
     self._saver = saver
     self._checkpoint_dir = checkpoint_dir
     self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
@@ -401,7 +399,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 
 
 class StepCounterHook(session_run_hook.SessionRunHook):
-  """Steps per second monitor."""
+  """Hook that counts steps per second."""
 
   def __init__(self,
                every_n_steps=100,
@@ -453,14 +451,13 @@ class NanLossDuringTrainingError(RuntimeError):
 
 
 class NanTensorHook(session_run_hook.SessionRunHook):
-  """NaN Loss monitor.
+  """Monitors the loss tensor and stops training if loss is NaN.
 
-  Monitors loss and stops training if loss is NaN.
   Can either fail with exception or just stop training.
   """
 
   def __init__(self, loss_tensor, fail_on_nan_loss=True):
-    """Initializes NanLoss monitor.
+    """Initializes a `NanTensorHook`.
 
     Args:
       loss_tensor: `Tensor`, the loss tensor.
@@ -494,7 +491,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
                summary_writer=None,
                scaffold=None,
                summary_op=None):
-    """Initializes a `SummarySaver` monitor.
+    """Initializes a `SummarySaverHook`.
 
     Args:
       save_steps: `int`, save summaries every N steps. Exactly one of
@@ -590,7 +587,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
 
 
 class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
-  """Delay execution until global step reaches to wait_until_step.
+  """Delays execution until global step reaches `wait_until_step`.
 
   This hook delays execution until global step reaches to `wait_until_step`. It
   is used to gradually start workers in distributed settings. One example usage
@@ -599,7 +596,7 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
   """
 
   def __init__(self, wait_until_step):
-    """Create a _GlobalStepWaiterHook.
+    """Initializes a `GlobalStepWaiterHook`.
 
     Args:
       wait_until_step: an `int` shows until which global step should we wait.
@@ -637,10 +634,10 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
 
 
 class FinalOpsHook(session_run_hook.SessionRunHook):
-  """A run hook which evaluates `Tensors` at the end of a session."""
+  """A hook which evaluates `Tensors` at the end of a session."""
 
   def __init__(self, final_ops, final_ops_feed_dict=None):
-    """Constructs the FinalOpHook with ops to run at the end of the session.
+    """Initializes `FinalOpsHook` with ops to run at the end of the session.
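A note on the `CheckpointSaverHook` change above: with the `_get_saver_or_default()` fallback, the hook can now be constructed with neither `saver` nor `scaffold`. A usage sketch under that assumption; import paths mirror the tests later in this diff, and `/tmp/model_dir` is a placeholder:

```python
# Hypothetical usage sketch of the new fallback: no saver, no scaffold.
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session

with ops.Graph().as_default():
  global_step = variables.get_or_create_global_step()
  train_op = state_ops.assign_add(global_step, 1)
  # No saver and no scaffold: the hook falls back to the shared saver
  # from the SAVERS collection, creating one if needed.
  hook = basic_session_run_hooks.CheckpointSaverHook(
      '/tmp/model_dir', save_steps=1)
  with monitored_session.SingularMonitoredSession(
      hooks=[hook], checkpoint_dir='/tmp/model_dir') as sess:
    sess.run(train_op)  # the default saver writes the checkpoint
```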
    Args:
       final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
@@ -666,10 +663,11 @@ class FeedFnHook(session_run_hook.SessionRunHook):
   """Runs `feed_fn` and sets the `feed_dict` accordingly."""
 
   def __init__(self, feed_fn):
-    """Constructs the FeedFnHook with given `feed_fn`.
+    """Initializes a `FeedFnHook`.
 
     Args:
-      feed_fn: function, no arguments and returns `dict` to feed.
+      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
+        to feed.
     """
     self.feed_fn = feed_fn
 
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index c2636d46f59..ecb61d447bf 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -346,6 +346,98 @@ class CheckpointSaverHookTest(test.TestCase):
         'end': 1
     }, listener.get_counts())
 
+  def test_listener_with_monitored_session(self):
+    with ops.Graph().as_default():
+      scaffold = monitored_session.Scaffold()
+      global_step = variables.get_or_create_global_step()
+      train_op = state_ops.assign_add(global_step, 1)
+      listener = MockCheckpointSaverListener()
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir,
+          save_steps=1,
+          scaffold=scaffold,
+          listeners=[listener])
+      with monitored_session.SingularMonitoredSession(
+          hooks=[hook],
+          scaffold=scaffold,
+          checkpoint_dir=self.model_dir) as sess:
+        sess.run(train_op)
+        sess.run(train_op)
+        global_step_val = sess.run(global_step)
+      listener_counts = listener.get_counts()
+    self.assertEqual(2, global_step_val)
+    self.assertEqual({
+        'begin': 1,
+        'before_save': 2,
+        'after_save': 2,
+        'end': 1
+    }, listener_counts)
+
+  def test_listener_with_default_saver(self):
+    with ops.Graph().as_default():
+      global_step = variables.get_or_create_global_step()
+      train_op = state_ops.assign_add(global_step, 1)
+      listener = MockCheckpointSaverListener()
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir,
+          save_steps=1,
+          listeners=[listener])
+      with monitored_session.SingularMonitoredSession(
+          hooks=[hook],
+          checkpoint_dir=self.model_dir) as sess:
+        sess.run(train_op)
+        sess.run(train_op)
+        global_step_val = sess.run(global_step)
+      listener_counts = listener.get_counts()
+    self.assertEqual(2, global_step_val)
+    self.assertEqual({
+        'begin': 1,
+        'before_save': 2,
+        'after_save': 2,
+        'end': 1
+    }, listener_counts)
+
+    with ops.Graph().as_default():
+      global_step = variables.get_or_create_global_step()
+      with monitored_session.SingularMonitoredSession(
+          checkpoint_dir=self.model_dir) as sess2:
+        global_step_saved_val = sess2.run(global_step)
+      self.assertEqual(2, global_step_saved_val)
+
+  def test_two_listeners_with_default_saver(self):
+    with ops.Graph().as_default():
+      global_step = variables.get_or_create_global_step()
+      train_op = state_ops.assign_add(global_step, 1)
+      listener1 = MockCheckpointSaverListener()
+      listener2 = MockCheckpointSaverListener()
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir,
+          save_steps=1,
+          listeners=[listener1, listener2])
+      with monitored_session.SingularMonitoredSession(
+          hooks=[hook],
+          checkpoint_dir=self.model_dir) as sess:
+        sess.run(train_op)
+        sess.run(train_op)
+        global_step_val = sess.run(global_step)
+      listener1_counts = listener1.get_counts()
+      listener2_counts = listener2.get_counts()
+    self.assertEqual(2, global_step_val)
+    self.assertEqual({
+        'begin': 1,
+        'before_save': 2,
+        'after_save': 2,
+        'end': 1
+    }, listener1_counts)
+    self.assertEqual(listener1_counts, listener2_counts)
+
+    with ops.Graph().as_default():
+      global_step = variables.get_or_create_global_step()
+      with monitored_session.SingularMonitoredSession(
+          checkpoint_dir=self.model_dir) as sess2:
+        global_step_saved_val = sess2.run(global_step)
+      self.assertEqual(2, global_step_saved_val)
+
   @test.mock.patch('time.time')
   def test_save_secs_saves_periodically(self, mock_time):
     # Let's have a realistic start time
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index ae76a1ab580..cf8692eda13 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 import abc
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -180,11 +179,7 @@ class Scaffold(object):
           summary.merge_all)
     # pylint: disable=g-long-lambda
     if self._saver is None:
-      self._saver = Scaffold.get_or_default(
-          'saver',
-          ops.GraphKeys.SAVERS,
-          lambda: training_saver.Saver(sharded=True, allow_empty=True,
-                                       write_version=saver_pb2.SaverDef.V2))
+      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
     # pylint: enable=g-long-lambda
     self._saver.build()
 
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index ae5fc54d854..43b61742467 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -712,6 +712,33 @@ class BaseSaverBuilder(object):
         version=self._write_version)
 
 
+def _get_saver_or_default():
+  """Returns the saver from SAVERS collection, or creates a default one.
+
+  This method is used by other members of the training module, such as
+  `Scaffold` or `CheckpointSaverHook`.
+
+  Returns:
+    `Saver`.
+
+  Raises:
+    RuntimeError: If the SAVERS collection already has more than one item.
+  """
+  collection_key = ops.GraphKeys.SAVERS
+  savers = ops.get_collection(collection_key)
+  if savers:
+    if len(savers) > 1:
+      raise RuntimeError(
+          "More than one item in collection {}. "
+          "Please indicate which one to use by passing it to the constructor.".
+          format(collection_key))
+    return savers[0]
+  saver = Saver(sharded=True, allow_empty=True)
+  if saver is not None:
+    ops.add_to_collection(collection_key, saver)
+  return saver
+
+
 def _GetCheckpointFilename(save_dir, latest_filename):
   """Returns a filename for storing the CheckpointState.
 
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index c0f0b309c1f..3172a1d5ba9 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1785,11 +1785,9 @@ class MetaGraphTest(test.TestCase):
     # Test that we can import a meta graph into a namescope.
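Before the saver_test hunk continues, a sketch of the contract of the `_get_saver_or_default()` helper added above. This is illustrative only, not part of the change; it exercises the three branches the docstring names:

```python
# Contract sketch: reuse a lone saver from SAVERS, create one if the
# collection is empty, and refuse to guess between two.
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training import saver as saver_lib

with ops.Graph().as_default():
  variables.Variable(0, name='v')  # give the saver something to save
  s1 = saver_lib._get_saver_or_default()  # creates a saver, adds it to SAVERS
  s2 = saver_lib._get_saver_or_default()  # finds and returns the same saver
  assert s1 is s2
  ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
  try:
    saver_lib._get_saver_or_default()  # two savers in SAVERS: ambiguous
  except RuntimeError as e:
    print(e)  # "More than one item in collection ..."
```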
    test_dir = self._get_test_dir("import_into_namescope")
     filename = os.path.join(test_dir, "ckpt")
-    image = array_ops.placeholder(dtypes.float32, [None, 784])
-    label = array_ops.placeholder(dtypes.float32, [None, 10])
+    image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
+    label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
     with session.Session() as sess:
-      label = array_ops.identity(label, name="label")
-      image = array_ops.identity(image, name="image")
       weights = variables.Variable(
           random_ops.random_uniform([784, 10]), name="weights")
       bias = variables.Variable(array_ops.zeros([10]), name="bias")
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index ad0fcd31e5f..ade62389810 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -99,7 +99,7 @@ def flatten(nest):
   return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest]
 
 
-def _recursive_assert_same_structure(nest1, nest2):
+def _recursive_assert_same_structure(nest1, nest2, check_types):
   is_sequence_nest1 = is_sequence(nest1)
   if is_sequence_nest1 != is_sequence(nest2):
     raise ValueError(
@@ -109,28 +109,31 @@ def _recursive_assert_same_structure(nest1, nest2):
   if is_sequence_nest1:
     type_nest1 = type(nest1)
     type_nest2 = type(nest2)
-    if type_nest1 != type_nest2:
+    if check_types and type_nest1 != type_nest2:
       raise TypeError(
           "The two structures don't have the same sequence type. First "
           "structure has type %s, while second structure has type %s."
           % (type_nest1, type_nest2))
 
     for n1, n2 in zip(nest1, nest2):
-      _recursive_assert_same_structure(n1, n2)
+      _recursive_assert_same_structure(n1, n2, check_types)
 
 
-def assert_same_structure(nest1, nest2):
+def assert_same_structure(nest1, nest2, check_types=True):
   """Asserts that two structures are nested in the same way.
 
   Args:
     nest1: an arbitrarily nested structure.
     nest2: an arbitrarily nested structure.
+    check_types: if `True` (default) types of sequences are checked as
+      well. If set to `False`, for example a list and a tuple of objects will
+      look the same if they have the same size.
 
   Raises:
     ValueError: If the two structures do not have the same number of elements
       or if the two structures are not nested in the same way.
     TypeError: If the two structures differ in the type of sequence in any of
-      their substructures.
+      their substructures. Only possible if `check_types` is `True`.
   """
   len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
   len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
   if len_nest1 != len_nest2:
     raise ValueError("The two structures don't have the same number of "
                      "elements. First structure: %s, second structure: %s."
                      % (nest1, nest2))
-  _recursive_assert_same_structure(nest1, nest2)
+  _recursive_assert_same_structure(nest1, nest2, check_types)
 
 
 def flatten_dict_items(dictionary):
@@ -266,7 +269,7 @@ def pack_sequence_as(structure, flat_sequence):
   return _sequence_like(structure, packed)
 
 
-def map_structure(func, *structure):
+def map_structure(func, *structure, **check_types_dict):
   """Applies `func` to each entry in `structure` and returns a new structure.
 
   Applies `func(x[0], x[1], ...)` where x[i] is an entry in
@@ -277,17 +280,24 @@ def map_structure(func, *structure):
     func: A callable that accepts as many arguments as there are structures.
     *structure: scalar, or tuple or list of constructed scalars and/or other
       tuples/lists, or scalars.  Note: numpy arrays are considered scalars.
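A quick illustration of the `check_types` flag added to `assert_same_structure` above; it mirrors the structures used in the nest tests later in this diff:

```python
# Minimal sketch of the new flag: identical nesting, different sequence types.
from tensorflow.python.util import nest

structure_tuple = (((1, 2), 3), 4, (5, 6))
structure_list = [[[1, 2], 3], 4, [5, 6]]

# Passes: nesting matches, tuple-vs-list mismatch is ignored.
nest.assert_same_structure(structure_tuple, structure_list, check_types=False)
try:
  nest.assert_same_structure(structure_tuple, structure_list)  # default: True
except TypeError as e:
  print(e)  # "The two structures don't have the same sequence type. ..."
```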
+    **check_types_dict: only valid keyword argument is `check_types`. If set
+      to `True` (default), the types of iterables within the structures have
+      to be the same (e.g. `map_structure(func, [1], (1,))` raises a
+      `TypeError` exception). To allow mismatched sequence types, set this
+      argument to `False`.
 
   Returns:
     A new structure with the same arity as `structure`, whose values correspond
     to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
-    location in `structure[i]`.
+    location in `structure[i]`. If there are different sequence types and
+    `check_types` is `False`, the sequence types of the first structure will be
+    used.
 
   Raises:
     TypeError: If `func` is not callable or if the structures do not match
       each other by depth tree.
     ValueError: If no structure is provided or if the structures do not match
       each other by type.
+    ValueError: If invalid keyword arguments are provided.
   """
   if not callable(func):
     raise TypeError("func must be callable, got: %s" % func)
 
@@ -295,8 +305,15 @@ def map_structure(func, *structure):
   if not structure:
     raise ValueError("Must provide at least one structure")
 
+  if check_types_dict:
+    if "check_types" not in check_types_dict or len(check_types_dict) > 1:
+      raise ValueError("Only valid keyword argument is check_types")
+    check_types = check_types_dict["check_types"]
+  else:
+    check_types = True
+
   for other in structure[1:]:
-    assert_same_structure(structure[0], other)
+    assert_same_structure(structure[0], other, check_types=check_types)
 
   flat_structure = [flatten(s) for s in structure]
   entries = zip(*flat_structure)
@@ -315,7 +332,7 @@ def _yield_flat_up_to(shallow_tree, input_tree):
       yield input_tree
 
 
-def assert_shallow_structure(shallow_tree, input_tree):
+def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
   """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
 
   That is, this function tests if the `input_tree` structure can be created from
@@ -341,11 +358,13 @@ def assert_shallow_structure(shallow_tree, input_tree):
   Args:
     shallow_tree: an arbitrarily nested structure.
     input_tree: an arbitrarily nested structure.
+    check_types: if `True` (default) the sequence types of `shallow_tree` and
+      `input_tree` have to be the same.
 
   Raises:
     TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
     TypeError: If the sequence types of `shallow_tree` are different from
-      `input_tree`.
+      `input_tree`. Only raised if `check_types` is `True`.
     ValueError: If the sequence lengths of `shallow_tree` are different from
       `input_tree`.
   """
   if is_sequence(shallow_tree):
     if not is_sequence(input_tree):
       raise TypeError(
           "If shallow structure is a sequence, input must also be a sequence. "
           "Input has type: %s." % type(input_tree))
 
-    if not isinstance(input_tree, type(shallow_tree)):
+    if check_types and not isinstance(input_tree, type(shallow_tree)):
       raise TypeError(
           "The two structures don't have the same sequence type. Input "
           "structure has type %s, while shallow structure has type %s."
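And the `map_structure` counterpart: because results are repacked via `pack_sequence_as(structure[0], ...)`, the output takes its sequence types from the first structure when `check_types=False`:

```python
# Sketch of map_structure's new keyword argument.
from tensorflow.python.util import nest

out = nest.map_structure(lambda x, y: x + y,
                         ((1, 2), 3), [[10, 20], 30], check_types=False)
print(out)  # ((11, 22), 33) -- tuple types come from the first structure
```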
@@ -368,7 +387,8 @@ def assert_shallow_structure(shallow_tree, input_tree):
           % (len(input_tree), len(shallow_tree)))
 
     for shallow_branch, input_branch in zip(shallow_tree, input_tree):
-      assert_shallow_structure(shallow_branch, input_branch)
+      assert_shallow_structure(shallow_branch, input_branch,
+                               check_types=check_types)
 
 
 def flatten_up_to(shallow_tree, input_tree):
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index f6a2d8b6631..8a17d990da2 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -139,6 +139,13 @@ class NestTest(test.TestCase):
                                  "don't have the same nested structure"):
       nest.assert_same_structure([[3], 4], [3, [4]])
 
+    structure1_list = [[[1, 2], 3], 4, [5, 6]]
+    with self.assertRaisesRegexp(TypeError,
+                                 "don't have the same sequence type"):
+      nest.assert_same_structure(structure1, structure1_list)
+    nest.assert_same_structure(structure1, structure2, check_types=False)
+    nest.assert_same_structure(structure1, structure1_list, check_types=False)
+
   def testMapStructure(self):
     structure1 = (((1, 2), 3), 4, (5, 6))
     structure2 = (((7, 8), 9), 10, (11, 12))
@@ -169,6 +176,23 @@ class NestTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "same nested structure"):
       nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
 
+    structure1_list = [[[1, 2], 3], 4, [5, 6]]
+    with self.assertRaisesRegexp(TypeError, "same sequence type"):
+      nest.map_structure(lambda x, y: None, structure1, structure1_list)
+
+    nest.map_structure(lambda x, y: None, structure1, structure1_list,
+                       check_types=False)
+
+    with self.assertRaisesRegexp(ValueError, "same nested structure"):
+      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
+                         check_types=False)
+
+    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+      nest.map_structure(lambda x: None, structure1, foo="a")
+
+    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
+
   def testAssertShallowStructure(self):
     inp_ab = ["a", "b"]
     inp_abc = ["a", "b", "c"]
@@ -186,6 +210,7 @@ class NestTest(test.TestCase):
                             "<(type|class) 'list'>.")
     with self.assertRaisesRegexp(TypeError, expected_message):
       nest.assert_shallow_structure(inp_ab2, inp_ab1)
+    nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
 
   def testFlattenUpTo(self):
     input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph-controls.html b/tensorflow/tensorboard/components/tf_graph/tf-graph-controls.html
index dce0708d0c4..f2a1b5658f2 100644
--- a/tensorflow/tensorboard/components/tf_graph/tf-graph-controls.html
+++ b/tensorflow/tensorboard/components/tf_graph/tf-graph-controls.html
@@ -145,7 +145,7 @@ svg.icon {
 .title small {
   font-weight: normal;
 }
-.deviceList {
+.deviceList, .xlaClusterList {
   max-height: 200px;
   overflow-y: auto;
 }
@@ -235,6 +235,12 @@ table.control-holder {
 table.tf-graph-controls td.input-element-table-data {
   padding: 0 0 0 20px;
 }
+
+/* Override inline styles that suppress pointer events for disabled buttons.
+   Otherwise, the tooltips do not appear. */
+#color-by-radio-group paper-radio-button {
+  pointer-events: auto !important;
+}
@@ -317,13 +323,39 @@ table.tf-graph-controls td.input-element-table-data {
Color
+            Structure
+            Device
+            XLA Cluster
+              Coloring by XLA cluster is only enabled if at least 1 op
+              specifies an XLA cluster.
+            Compute time
+              Coloring by compute time is only enabled if the RunMetadata
+              proto is passed to the FileWriter when a specific session is
+              run.
+            Memory
+              Coloring by memory is only enabled if the RunMetadata proto is
+              passed to the FileWriter when a specific session is run.
@@ -410,6 +442,32 @@ table.tf-graph-controls td.input-element-table-data {