diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 2688e92d08a..6a70c0e4057 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -297,6 +297,7 @@ filegroup( "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", "//tensorflow/tensorboard/components/vz_data_summary:all_files", "//tensorflow/tensorboard/components/vz_line_chart:all_files", "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 295777e897a..d4bcc01b6b8 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -135,6 +135,9 @@ class TF_ManagedBuffer : public TensorBuffer { proto->set_requested_bytes(rb); proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } }; void* allocate_tensor(const char* operation, size_t len) { diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 9a41d2bb1d7..9d1870af0af 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -314,6 +314,7 @@ tf_gen_op_wrappers_cc( name = "cc_ops", op_lib_names = [ "array_ops", + "audio_ops", "candidate_sampling_ops", "control_flow_ops", "data_flow_ops", diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index 70b9c6fb0fb..8b43c7c1564 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -260,6 +260,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(&xla_allocator); + run_options.set_inter_op_thread_pool( + ctx->device()->tensorflow_cpu_worker_threads()->workers); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 0d3a2fa3393..f68dba6b6a2 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -19,13 +19,6 @@ limitations under the License. namespace tensorflow { -void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs, - const Tensor& rhs) { - std::shared_ptr gd = - XlaTransferManager::GetTensorGlobalData(rhs); - XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs); -} - XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index b084dcaa7d9..a52239df252 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/kernels/assign_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/identity_op.h" @@ -30,14 +29,6 @@ limitations under the License. namespace tensorflow { -// Implementation of Assign for XLA devices. -class XlaDeviceAssignOp : public AssignOp { - public: - using AssignOp::AssignOp; - - void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override; -}; - // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be // rewritten to a _XlaLaunch op. If it is called, it means the placer placed an @@ -72,28 +63,6 @@ class XlaDeviceDummyOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ - REGISTER_KERNEL_BUILDER( \ - Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - VariableOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - VariableOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - TemporaryVariableOp); \ - REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ - .Device(DEVICE) \ - .TypeConstraint("T", TYPES), \ - DestroyTemporaryVariableOp); \ - REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \ - .Device(DEVICE) \ - .TypeConstraint("dtype", TYPES) \ - .HostMemory("is_initialized"), \ - IsVariableInitializedOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \ - XlaDeviceAssignOp); \ - \ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ ControlTriggerOp); \ REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc index cbfa239ca88..d27520213a7 100644 --- a/tensorflow/compiler/tf2xla/op_registrations.cc +++ b/tensorflow/compiler/tf2xla/op_registrations.cc @@ -614,9 +614,12 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes)); REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes)); + +// TODO(b/34969189) The implementation of TruncatedNormal triggers a bug on GPU. +// REGISTER_XLA_KERNEL( +// DEVICE_GPU_XLA_JIT, +// Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes)); + REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Unpack").TypeConstraint("T", kGpuAllTypes)); REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("VarIsInitializedOp")); diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 2a8a4b321a6..973204eae2d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -301,8 +301,7 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < computations.size(); ++i) { + for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( MakeUnique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc index cd7d8df58b8..0f9ca7b4fe4 100644 --- a/tensorflow/compiler/xla/client/computation.cc +++ b/tensorflow/compiler/xla/client/computation.cc @@ -33,7 +33,7 @@ Computation::Computation(Computation&& computation) } void Computation::Reset() { - // TODO(leary) deallocate any owned computation. + // TODO(b/34469253) deallocate any owned computation. ResetWithoutFreeing(); } diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 4b7f9c1822c..88efd87d1cc 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -106,9 +106,7 @@ bool ComputationBuilder::MakeWindow( tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const tensorflow::gtl::ArraySlice< - int64>::size_type x, - const char* x_name) { + const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return true; } else { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 0d2d541d34f..87ceb43d1fe 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -541,8 +541,8 @@ class ComputationBuilder { // (float32 is specified as there is an implicit float32 -1.0f constant // exponent). // - // TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of - // the operand. + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); // Enqueues a negate instruction onto the computation. @@ -839,7 +839,7 @@ template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR4FromArray4D(values, layout, literal); + LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal); }); } diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1b764564912..bfd14bc1c01 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -309,6 +309,14 @@ int LocalClient::default_device_ordinal() const { return local_service_->backend().default_device_ordinal(); } +const Backend& LocalClient::backend() const { + return local_service_->backend(); +} + +Backend* LocalClient::mutable_backend() { + return local_service_->mutable_backend(); +} + StatusOr> LocalClient::Compile( const Computation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 7bffc75ab5f..2c467efcea1 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -224,6 +224,10 @@ class LocalClient : public Client { // capability). bool device_ordinal_supported(int device_ordinal) const; + // Returns the backend used to execute computations. + const Backend& backend() const; + Backend* mutable_backend(); + private: LocalService* local_service_; }; diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 8d758157115..0b18d8946a2 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -35,8 +35,7 @@ std::vector> MakePadding( return low_high_padding; case Padding::kSame: - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < input_dimensions.size(); ++i) { + for (size_t i = 0; i < input_dimensions.size(); ++i) { int64 input_dimension = input_dimensions[i]; int64 window_dimension = window_dimensions[i]; int64 window_stride = window_strides[i]; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index eb937c36146..e3248d8e908 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -32,8 +32,7 @@ namespace xla { // Padding and nested layouts not supported yet. DCHECK_EQ(0, shape.layout().padded_dimensions_size()); - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < multi_index.size(); ++i) { + for (size_t i = 0; i < multi_index.size(); ++i) { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 15d1fe7852c..b9118fab254 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1133,6 +1133,22 @@ cc_library( ], ) +cc_library( + name = "hlo_verifier", + srcs = ["hlo_verifier.cc"], + hdrs = ["hlo_verifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_rematerialization", srcs = ["hlo_rematerialization.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 327234688f7..4c00ec37965 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -896,7 +896,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig no_padding; - for (auto i = 0; i < 2; ++i) { + for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); dimension->set_edge_padding_low(0); dimension->set_edge_padding_high(0); @@ -926,7 +926,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; - for (auto i = 0; i < 2; ++i) { + for (int i = 0; i < 2; ++i) { auto dimension = padding.add_dimensions(); dimension->set_edge_padding_low(low_padding[i]); dimension->set_edge_padding_high(high_padding[i]); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 8f169cd0368..e59fad4e052 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -138,8 +138,7 @@ tensorflow::Status AllocationTracker::DeallocateShape( TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) << "tuple has unexpected number of elements: " << elements.size() << " != " << ShapeUtil::TupleElementCount(shape); - for (std::vector::size_type i = 0; - i < elements.size(); ++i) { + for (size_t i = 0; i < elements.size(); ++i) { VLOG(2) << "recursing onto the tuple elements"; TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], shape.tuple_shapes(i), diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index f1fc608caa0..e2b550fc022 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -212,6 +212,13 @@ StatusOr BufferAssignment::GetUniqueTopLevelSlice( return GetUniqueSlice(instruction, /*index=*/{}); } +bool BufferAssignment::SharesSliceAtIndex( + const HloInstruction* hlo_a, const ShapeIndex& shape_index_a, + const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const { + return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() == + GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); +} + StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 82b9bf49ece..b82acb19b34 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -294,6 +294,15 @@ class BufferAssignment { return GetPointsToSet(instruction).element(index); } + // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}' + // share the same BufferAllocation::Slice. + // Returns false otherwise. + // REQUIRES: BufferAssignment assigned allocations to both instructions. + bool SharesSliceAtIndex(const HloInstruction* hlo_a, + const ShapeIndex& shape_index_a, + const HloInstruction* hlo_b, + const ShapeIndex& shape_index_b) const; + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 0fe6e37c00f..736f227aa42 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // *) Is element-wise. // *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where // the singleton use of 'a' at 'a.index' is the fused root at operand 0. + // *) Use of 'operand' is DynamicUpdateSlice at operand index 0. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index e7aa93f8dbc..e71b98298b3 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); } +class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { + protected: + // Builds and runs a computation (see test case computation graphs below). + // Runs BufferLiveness on this computation. + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool tuple_element1_has_two_uses) { + auto builder = HloComputation::Builder(TestName()); + // Create param0 Tuple. + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {3}); + auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0")); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0)); + + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); + + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + + if (tuple_element1_has_two_uses) { + // Add 'gte0' and 'gte1' to create another user of 'gte1'. + gte0 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, gte0, gte1)); + } + // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + // Create output tuple. + auto tuple_root = builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + // Build module and get reference to entry computation. + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + // Run BufferLiveness on 'module'. + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + // Return whether or not buffers interfernce is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); + } +}; + +// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in +// the following computation (because DynamicUpdateSlice (at operand 0) is the +// unique user): +// +// Parameter0 +// | | +// GTE(0) GTE(1) Const Const +// | \ | / +// | DynamicUpdateSlice +// \ / +// Tuple +// +TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) { + EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false)); +} + +// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because +// GTE(1) has two users: +// 1) DynamicUpdateSlice at operand 0. +// 2) Add at operand 1. +// +// Parameter0 +// | | +// GTE(0) GTE(1) +// | / | +// | / | +// Add | Const Const +// | | | | +// | DynamicUpdateSlice +// \ / +// Tuple +// +TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) { + EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 45cbe2b7aeb..6f43c9b8040 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -110,8 +110,7 @@ class Compiler { // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // - // TODO(leary) will need to update this API when a single computation can run - // across multiple devices simultaneously. + // Use the overload below to compile computations that run in parallel. virtual StatusOr> Compile( std::unique_ptr module, std::unique_ptr module_config, HloDumper dump_hlo, diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index abb7ace0c05..e9963528111 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 786c9847dc0..c5433d4b89d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -68,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" @@ -214,6 +215,7 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo) { // Optimization pipeline. HloPassPipeline pipeline("CPU", dump_hlo); + pipeline.AddInvariantChecker(); // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding // where we will take this pass in future. @@ -573,8 +575,7 @@ CpuCompiler::CompileAheadOfTime( } std::vector> results; - for (std::vector>::size_type i = 0; - i < hlo_modules.size(); ++i) { + for (size_t i = 0; i < hlo_modules.size(); ++i) { HloModule* hlo_module = hlo_modules[i].get(); HloModuleConfig* module_config = module_configs[i].get(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index b7c646ad47d..0eca4c3473e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -24,8 +24,9 @@ namespace cpu { class CpuInstructionFusion : public InstructionFusion { public: - CpuInstructionFusion() {} - ~CpuInstructionFusion() override {} + CpuInstructionFusion() + : InstructionFusion(CpuInstructionFusion::IsExpensive) {} + ~CpuInstructionFusion() override = default; protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a7972db3ebe..51c6dc4426f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1111,7 +1111,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (auto i = 0; i < input_index.size(); ++i) { + for (size_t i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -1180,7 +1180,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index; - for (auto i = 0; i < operand_index.size(); ++i) { + for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + @@ -1294,12 +1294,12 @@ Status IrEmitter::HandleCustomCall( llvm_ir::EmitAllocaAtFunctionEntryWithCount( i8_ptr_type, ir_builder_.getInt32(operands.size()), "cc_operands_alloca", &ir_builder_); - for (auto i = 0; i < operands.size(); ++i) { + for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( - operands_alloca, {ir_builder_.getInt32(i)}); + operands_alloca, {ir_builder_.getInt64(i)}); ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = @@ -1659,13 +1659,13 @@ void IrEmitter::EmitArrayFunctionCallInto( ir_builder_.getInt32(parameter_addresses.size()), tensorflow::strings::StrCat(name, "_parameter_addresses"), &ir_builder_); - for (auto i = 0; i < parameter_addresses.size(); ++i) { + for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( parameter_addresses[i], ir_builder_.getInt8PtrTy(), llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, "_address_as_i8ptr"))); llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt32(i)}); + parameter_addresses_buffer, {ir_builder_.getInt64(i)}); ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index bab3440e2c2..7a4723e8d75 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -97,77 +97,81 @@ static void MarkLiveAddressesInOutput( } } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. - se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations; - for (BufferAllocation::Index i = 0; i < buffer_count; ++i) { +Status ParallelCpuExecutable::AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers) { + CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + VLOG(3) << "Allocating " << assignment_->Allocations().size() + << " allocations for module " << module().name(); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { auto& allocation = assignment_->GetAllocation(i); + + VLOG(3) << allocation.ToString(); + if (allocation.is_entry_computation_parameter()) { - // Buffers do not need to be allocated for parameters. - device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + VLOG(3) << "allocation #" << i << " is a parameter"; continue; } if (allocation.is_thread_local()) { - // Buffers do not need to be allocated for thread-local temporaries. - device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + VLOG(3) << "buffer #" << i << " is thread-local"; continue; } - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase device_allocation, - memory_allocator->Allocate(device_ordinal, allocation.size())); + int64 buffer_size = allocation.size(); + if (!(*buffers)[i].is_null()) { + VLOG(3) << "buffer #" << i + << " is in the preallocated result ShapedBuffer"; + } else { + TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); - if (VLOG_IS_ON(3)) { - VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size() - << " bytes for allocation #" << i << " [" - << device_allocation.opaque() << "]"; - std::vector parts; - for (const auto& buffer_offset_size : allocation.assigned_buffers()) { - const LogicalBuffer& buffer = *buffer_offset_size.first; - parts.push_back(tensorflow::strings::StrCat( - buffer.instruction()->parent()->name(), "::", buffer.ToString())); - } - VLOG(3) << " " << tensorflow::str_util::Join(parts, ", "); + VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" + << (*buffers)[i].opaque() << "]"; } - device_allocations.push_back(device_allocation); // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(), - allocation.size()); + TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); - const BufferAllocation::Index result_index = result_slice.index(); - VLOG(3) << "result index: " << result_index; + VLOG(3) << "result index: " << result_slice.index(); + return Status::OK(); +} + +Status ParallelCpuExecutable::ExecuteComputeFunctions( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { + std::vector argument_buffers(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape())); + argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); + } + return ExecuteComputeFunctions(run_options, argument_buffers, buffers, + hlo_execution_profile); +} + +Status ParallelCpuExecutable::ExecuteComputeFunctions( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to // profile. Allocate an additional profile counter for the entire // computation. std::vector profile_counters(hlo_to_profile_idx_.size() + 1); std::vector buffer_pointers; - for (auto& device_allocation : device_allocations) { + buffer_pointers.reserve(buffers.size()); + for (auto device_allocation : buffers) { buffer_pointers.push_back(device_allocation.opaque()); } @@ -210,8 +214,7 @@ ParallelCpuExecutable::ExecuteOnStream( void** temps_array = buffer_pointers.data(); uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = - CHECK_NOTNULL(run_options->run_options().inter_op_thread_pool()); + auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); tensorflow::mutex completion_queue_lock; tensorflow::condition_variable completion_queue_cv; std::deque completion_queue; @@ -310,6 +313,42 @@ ParallelCpuExecutable::ExecuteOnStream( } } + return Status::OK(); +} + +StatusOr +ParallelCpuExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); + if (!arguments.empty()) { + VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); + } + + // Allocate the temporary buffers required for the computation. + se::StreamExecutor* stream_executor = stream->parent(); + int device_ordinal = stream_executor->device_ordinal(); + int64 buffer_count = assignment_->Allocations().size(); + VLOG(3) << "temp buffer count: " << buffer_count; + + std::vector device_allocations( + assignment_->Allocations().size()); + TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, + stream->parent()->device_ordinal(), + &device_allocations)); + + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + const BufferAllocation::Index result_index = result_slice.index(); + VLOG(3) << "result index: " << result_index; + + TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(), + arguments, device_allocations, + hlo_execution_profile)); + // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. std::unordered_set marked_addresses; @@ -328,7 +367,7 @@ ParallelCpuExecutable::ExecuteOnStream( // live because they are referenced by the output of the computation // and are needed by the service. They will be deallocated by the // service. - for (auto i = 0; i < device_allocations.size(); ++i) { + for (size_t i = 0; i < device_allocations.size(); ++i) { auto alloc = device_allocations[i]; if (marked_addresses.count(alloc.opaque()) == 0 && alloc.opaque() != nullptr) { @@ -345,8 +384,74 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - return Unimplemented( - "ParallelCpuExecutable not supported yet with LocalService execution"); + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } + + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr result_buffer, + ShapedBuffer::MakeShapedBuffer( + result_shape(), stream->parent()->platform(), + stream->parent()->device_ordinal())); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunctions( + &run_options->run_options(), arguments, buffers, hlo_execution_profile)); + + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_RETURN_IF_ERROR( + result_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&buffers, &buffers_in_result, &result_buffer, this]( + const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { + if (is_leaf) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer + // such as a tuple element. + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( + src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *buffer_entry = result_buffer->mutable_buffers()->size(); + result_buffer->mutable_buffers()->push_back(buffer); + buffers_in_result[buffer_index] = true; + } + return Status::OK(); + })); + + // Free all buffers not in the result. + for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; + if (!buffers_in_result[i] && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate( + stream->parent()->device_ordinal(), &alloc)); + } + } + + return std::move(result_buffer); } StatusOr @@ -358,5 +463,10 @@ ParallelCpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream is not yet supported on CPU."); } +const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { + return assignment_->points_to_analysis().GetPointsToSet( + module().entry_computation()->root_instruction()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 7ce059bb1da..7223de9f079 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -84,6 +84,35 @@ class ParallelCpuExecutable : public Executable { } private: + // Allocate buffers required for execution and assign them to the elements of + // "buffers". "buffers" should be sized to the number of buffers in buffer + // assignment. Each vector element corresponds to a particular Index. If + // a vector element already contains a non-null DeviceMemoryBase, then no + // buffer is assigned for this element. + Status AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers); + + // Calls the generated functions in 'function_names_', performing the + // computation with the given arguments using the supplied buffers. + Status ExecuteComputeFunctions( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunctions( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + + // Returns the points-to set of the root instruction of the entry + // computation. Uses points-to analysis from buffer assignment. + const PointsToSet& GetRootPointsToSet() const; + // The JIT containing compiled modules. tensorflow::mutex jit_mutex_; std::unique_ptr jit_ GUARDED_BY(jit_mutex_); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 96342451fd6..a04815dad94 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -937,6 +937,68 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); + case HloOpcode::kPad: + return [=, &operand_to_generator]( + const IrArray::Index& padded_index) -> StatusOr { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], + index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); + }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str()); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 6a5e904f17d..ef973676ea4 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -40,8 +40,7 @@ Executable::ExecuteOnStreams( std::vector return_values( run_options.size()); - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < run_options.size(); ++i) { + for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index b5307ad1df3..eb36aba33a7 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -39,9 +39,6 @@ namespace xla { // A given platform's compiler will produce an Executable -- this is a uniform // interface that is used for launching compiled programs across platforms. -// -// TODO(leary) will need to extend this to support multiple streams/devices as -// we begin to compile single programs to run on multiple devices. class Executable { public: explicit Executable(std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 715d3f33bc0..7b87ac6da1d 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -118,7 +118,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice( // Create a DeviceMemoryBase from each void* pointer. std::vector destination; - for (std::vector::size_type i = 0; i < element_pointers.size(); ++i) { + for (size_t i = 0; i < element_pointers.size(); ++i) { if (element_pointers[i] == nullptr && !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %lu", i); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1c078835dea..9de6d65a27b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -356,13 +356,12 @@ cc_library( srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], deps = [ + ":instruction_fusion", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/core:lib", ], ) @@ -434,6 +433,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index b8fb81fa76e..9fdf717b5d4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -87,7 +87,8 @@ tensorflow::Status BufferAllocations::TearDown( const std::set& live_addresses, const BufferAssignment& buffer_assignment) { // Deallocate temporary buffers. - for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) { + const int64 num_buffers = buffer_assignment.Allocations().size(); + for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1cb03db4eee..1667ab36792 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -270,69 +270,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { switch (hlo->opcode()) { - case HloOpcode::kPad: - return [=, &operand_to_generator]( - const IrArray::Index& padded_index) -> StatusOr { - auto index = padded_index; - llvm::Value* in_bounds = - llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); - for (auto i = 0; i < index.size(); ++i) { - auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); - }; - const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpEQ( - index_typed_const(0), - ir_builder_->CreateURem( - index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = ir_builder_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); - } - - // if (in_bounds) { - // ret_value = operand0[index]; // source - // } else { - // ret_value = *operand1; // padding - // } - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), - "pad_result_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - // Don't create phi(operand_value, padding_value) here, because invoking - // operand_to_generator may create new basic blocks, making the parent - // of operand_value or padding_value no longer a predecessor of - // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); - }; case HloOpcode::kMap: return [=, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 0ee117afcd9..afb78b8300b 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -221,7 +221,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { fusion->fused_instructions().end(), [](const std::unique_ptr& instruction) { if (instruction->opcode() != HloOpcode::kParameter && - IsExpensive(*instruction)) { + GpuInstructionFusion::IsExpensive(*instruction)) { return false; } return true; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 95f1d05fec0..f692f28bd98 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -121,6 +122,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, const se::DeviceDescription& device_desc) { { HloPassPipeline pipeline("optimization", dump_hlo); + pipeline.AddInvariantChecker(); { auto& pass = pipeline.AddPass>( "simplification", dump_hlo); @@ -157,6 +159,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo); + pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass( module_config->mutable_entry_computation_layout()); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 21f3b542a27..bb2990e6dfc 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -25,7 +25,7 @@ namespace gpu { class GpuInstructionFusion : public InstructionFusion { public: explicit GpuInstructionFusion(bool may_duplicate) - : InstructionFusion(may_duplicate) {} + : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index d17ef5a67e6..5f3ce85f857 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -513,7 +513,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (auto i = 0; i < input_index.size(); ++i) { + for (size_t i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -614,7 +614,7 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( llvm_ir::IrArray::Index index = loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); // Verify every dimension except the reduction dimension was set in the index. - for (auto dimension = 0; dimension < index.size(); ++dimension) { + for (size_t dimension = 0; dimension < index.size(); ++dimension) { if (dimension == reduction_dimension) { DCHECK_EQ(nullptr, index[dimension]); } else { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f9698a06747..9b7aa7c860b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, return false; } auto* operand = fusion->operand(fusion_operand->parameter_number()); - - BufferAllocation::Slice operand_slice = - assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie(); - - BufferAllocation::Slice fusion_slice = - assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie(); - - return operand_slice == fusion_slice; + return assignment.SharesSliceAtIndex(fusion, {}, operand, index); } } // namespace @@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0. - ShapeIndex index_unused; - auto* fusion_operand = - LatestNonGteAncestorAndIndex(root->operand(0), &index_unused); + auto* fusion_operand = LatestNonGteAncestor(root->operand(0)); CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode()); // Operand(0) the input array which shares an allocation with the output. diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 8addcd87eaa..bdb062837c5 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -79,7 +79,7 @@ ThunkSchedule::ThunkSchedule( void ThunkSchedule::RemoveRedundantDependencyEdges() { std::unordered_map thunk_to_total_order; - for (auto i = 0; i < thunk_total_order_.size(); ++i) { + for (int i = 0; i < thunk_total_order_.size(); ++i) { InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ccc2c387494..35f8dcb7ca6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -92,13 +92,22 @@ HloInstruction* HloComputation::AddInstructionInternal( // Generate a unique name for the instruction. instruction->set_name( instruction_name_uniquer_.GetUniqueName(instruction->name())); - instruction->set_parent(this); + Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); return pinst; } +void HloComputation::Reparent(HloInstruction* instruction) { + instruction->set_parent(this); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& i : instruction->fused_instructions()) { + Reparent(i.get()); + } + } +} + /* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) { return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace || diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 08cde9d0321..ef3cba6fa08 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -235,6 +235,14 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); + // Helper for setting the parent of instructions that are added to this + // computation. + // + // Because we clone HLO instructions without knowing what computation they're + // destined to be added to, this is required to appropriate set the parent on + // fused instruction sequences. + void Reparent(HloInstruction* instruction); + // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fa0e3d934c0..0af4c99d0a5 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -75,8 +75,7 @@ string InstructionSequenceGraph( std::vector param_instructions; for (auto& instruction : instructions) { if (instruction->opcode() == HloOpcode::kParameter) { - std::vector::size_type param_number = - instruction->parameter_number(); + size_t param_number = instruction->parameter_number(); if (param_instructions.size() < param_number + 1) { param_instructions.resize(param_number + 1, nullptr); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ded9934a057..905647c2ed9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -431,6 +431,7 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; + instruction->set_parent(fused_root->parent()); instruction->CloneAndFuseInternal(fused_root); instruction->CheckFusionInstruction(); return instruction; @@ -568,6 +569,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( std::unique_ptr param_instruction = CreateParameter(param_no, operand->shape(), "fusion_param"); + param_instruction->set_parent(parent()); param_instruction->parent_fusion_instruction_ = this; fused_parameters_.push_back(param_instruction.get()); fused_instructions_.push_back(std::move(param_instruction)); @@ -602,6 +604,7 @@ void HloInstruction::CheckFusionInstruction() const { for (auto& instruction : fused_instructions_) { CHECK(instruction->IsFused()); CHECK_EQ(this, instruction->fusion_instruction()); + CHECK_EQ(parent(), instruction->parent()) << instruction->ToString(); } // Fused root instruction and fused parameters must all be owned by the fusion @@ -838,12 +841,14 @@ std::unique_ptr HloInstruction::Clone(const string& suffix) { std::unique_ptr clone = CloneWithNewOperands(shape_, operands_); clone->name_ = name() + "." + suffix; + clone->set_parent(parent()); return clone; } std::unique_ptr HloInstruction::CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands) { CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(parent() != nullptr); auto new_instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); @@ -883,6 +888,7 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( old_fused_instruction->CloneWithNewOperands( old_fused_instruction->shape(), new_operands)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); + new_fused_instruction->set_parent(parent()); new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } @@ -893,6 +899,7 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( new_instruction->fused_instructions_ = std::move(new_fused_instructions); new_instruction->fused_parameters_ = std::move(new_fused_parameters); new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_); + new_instruction->set_parent(parent()); new_instruction->CheckFusionInstruction(); return new_instruction; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f075ba6ea52..6557ca91163 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -538,6 +538,9 @@ class HloInstruction { // instruction. The order is a reverse postorder of the fused expression (root // is first in the order). // + // Note: although the list itself is const, the instructions contained in the + // list returned here are mutable. + // // Precondition: opcode() == HloOpcode::kFusion const std::list>& fused_instructions() const; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 91468fd35b0..6e3c9830712 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::StrAppend; + namespace xla { namespace { @@ -44,6 +46,14 @@ StatusOr HloPassPipeline::Run(HloModule* module) { tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); tensorflow::gtl::FlatSet disabled_passes(tmp.begin(), tmp.end()); + auto run_invariant_checkers = [this, module]() -> Status { + for (auto& invariant_checker : invariant_checkers_) { + TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module)); + TF_RET_CHECK(!changed) << "invariant checkers must not change the graph"; + } + return Status::OK(); + }; + string prefix = name().ToString() + ": pipeline start"; bool changed = false; string message; @@ -55,15 +65,17 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); - tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name()); + StrAppend(&message, prefix, ", before ", pass->name()); DumpModule(dumper_, *module, message); + TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); changed |= changed_this_pass; prefix.clear(); - tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name()); + StrAppend(&prefix, name(), ": after ", pass->name()); } + TF_RETURN_IF_ERROR(run_invariant_checkers()); DumpModule(dumper_, *module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 7a9c606a487..a8c2d518730 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -52,6 +52,16 @@ class HloPassPipeline : public HloPassInterface { return *pass; } + // Add an invariant-checking pass to the pipeline. It will be run before and + // after each HLO pass. The invariant checking pass must not mutate the graph + // (it is required to always return "false" from its Run() method). + template + T& AddInvariantChecker(Args&&... args) { + auto pass = new T(std::forward(args)...); + invariant_checkers_.push_back(std::unique_ptr(pass)); + return *pass; + } + // Run all passes on the given HLO module. StatusOr Run(HloModule* module) override; @@ -59,6 +69,7 @@ class HloPassPipeline : public HloPassInterface { const string name_; Compiler::HloDumper dumper_; std::vector> passes_; + std::vector> invariant_checkers_; TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc new file mode 100644 index 00000000000..035b570ed34 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" + +namespace xla { + +StatusOr HloVerifier::Run(HloModule* module) { + for (auto& computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RET_CHECK(instruction->parent() == computation.get()); + if (instruction->opcode() == HloOpcode::kFusion) { + for (const auto& fused : instruction->fused_instructions()) { + TF_RET_CHECK(fused->parent() == computation.get()) + << "Fused HLO was missing a parent: " << fused->ToString() + << " parent: " << fused->parent() + << " computation: " << computation.get(); + } + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h new file mode 100644 index 00000000000..5159420b3fb --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// HLO pass that verifies invariants of HLO instructions for each computation in +// the module. +class HloVerifier : public HloPassInterface { + public: + ~HloVerifier() override = default; + tensorflow::StringPiece name() const override { return "verifier"; } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7160129c12d..c162945bcae 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -29,7 +29,8 @@ limitations under the License. namespace xla { -bool IsExpensive(const HloInstruction& instruction) { +/*static*/ bool InstructionFusion::IsExpensive( + const HloInstruction& instruction) { switch (instruction.opcode()) { // Cheap instructions. case HloOpcode::kAbs: @@ -105,9 +106,14 @@ bool IsExpensive(const HloInstruction& instruction) { return false; } -bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) { - return !(producer->users().size() == 1 && consumer->IsUserOf(producer)); +namespace { +// Returns true if fusing producer into consumer would cause producer to be +// duplicated. This is the case if producer has uses other than consumer. +bool FusionWouldDuplicate(const HloInstruction& producer, + const HloInstruction& consumer) { + return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); } +} // namespace StatusOr InstructionFusion::Run(HloModule* module) { bool changed = false; @@ -125,8 +131,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { std::vector post_order(post_order_list.begin(), post_order_list.end()); tensorflow::gtl::FlatMap post_order_index; - for (std::vector::size_type i = 0; i < post_order.size(); - ++i) { + for (size_t i = 0; i < post_order.size(); ++i) { InsertOrDie(&post_order_index, post_order[i], i); } @@ -263,8 +268,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Cost condition: don't duplicate expensive instructions. - if (FusionWouldDuplicate(producer, consumer) && - (IsExpensive(*producer) || !may_duplicate_)) { + if (FusionWouldDuplicate(*producer, *consumer) && + (is_expensive_(*producer) || !may_duplicate_)) { return false; } @@ -277,7 +282,7 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, // Cost condition: not fuse (expensive producers) and (consumers who reuse // operand elements). if (consumer->ReusesOperandElements(operand_index) && - IsExpensive(*producer)) { + is_expensive_(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index b8fd3dd4f37..a9f3723f2df 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,15 +24,6 @@ limitations under the License. namespace xla { -// Returns true if the computation of the given instruction is significantly -// more expensive than just writing all the values of the instructions' result -// array. Expensive operations should not be duplicated. -bool IsExpensive(const HloInstruction& instruction); - -// Returns true if fusing producer into consumer would cause producer to be -// duplicated. This is the case if producer has uses other than consumer. -bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer); - // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer); // instructions to fuse. class InstructionFusion : public HloPassInterface { public: - explicit InstructionFusion(bool may_duplicate = true) - : may_duplicate_(may_duplicate) {} - ~InstructionFusion() override {} + explicit InstructionFusion( + std::function is_expensive, + bool may_duplicate = true) + : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} + ~InstructionFusion() override = default; tensorflow::StringPiece name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). StatusOr Run(HloModule* module) override; + // Returns true if the computation of the given instruction is significantly + // more expensive than just writing all the values of the instructions' result + // array. Expensive operations will not be duplicated. + static bool IsExpensive(const HloInstruction& instruction); + protected: // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. @@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface { private: HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // Used to determine if an HLO is expensive. Expensive operations will not be + // duplicated. + std::function is_expensive_; + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 2e3742ed75f..a4c269f0ebd 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -36,7 +36,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(broadcast2, computation->root_instruction()); } @@ -55,7 +57,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -73,7 +77,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -91,7 +97,9 @@ TEST_F(InstructionFusionTest, auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); } @@ -106,7 +114,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { @@ -120,7 +130,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { @@ -134,7 +146,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 7d157e8fd5f..caaf56a5516 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -106,6 +106,7 @@ std::vector> GetAllUsesOfInstructionAtIndex( // *) Is a loop fusion instruction where the only use of 'operand' at 'index' // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root // at operand 0. +// *) Use of 'operand' is DynamicUpdateSlice at operand index 0. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, @@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser( break; } return false; + } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; } // Check if 'user' is element-wise. return user->IsElementwise(); diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 346e0fefcb8..451bb8c7ead 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -256,8 +256,7 @@ StatusOr> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, const Backend* backend, int device_ordinal) { std::vector allocations; - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < arguments.size(); ++i) { + for (size_t i = 0; i < arguments.size(); ++i) { auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); if (!allocation_status.ok()) { return Status(allocation_status.status().code(), @@ -296,8 +295,7 @@ StatusOr> Service::CreateModuleConfig( program_shape.parameters_size(), arguments.size()); } - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < arguments.size(); ++i) { + for (size_t i = 0; i < arguments.size(); ++i) { // Verify that shape of arguments matches the shape of the arguments in the // ProgramShape. if (!ShapeUtil::Compatible(arguments[i]->shape(), @@ -385,8 +383,7 @@ StatusOr>> Service::BuildExecutables( hlo_dumper, std::move(executors))); if (!other_directory_path.empty()) { - for (std::vector::size_type i = 0; - i < versioned_handles.size(); ++i) { + for (size_t i = 0; i < versioned_handles.size(); ++i) { executables[i]->set_session_module(std::move(session_modules[i])); } } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 38d87940658..4a4a6e64ffa 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -598,7 +598,9 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // Build computation and add it to module as entry computation. BuildModule(builder.Build()); // Run instruction fusion HloPass. - EXPECT_TRUE(InstructionFusion().Run(module_.get()).ValueOrDie()); + EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive) + .Run(module_.get()) + .ValueOrDie()); // Get computation root instruction (should be a kFusion). auto* fusion = module_->entry_computation()->root_instruction(); EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 92c08d03acd..34f82603e89 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -56,6 +56,8 @@ class ClientLibraryTestBase : public ::testing::Test { execution_options_.set_disable_fast_math(disabled); } + void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } + // TODO(b/25566808): Add helper that populates a literal from a testdata file. // Convenience methods for building and running a computation from a builder. diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 8948d77aec8..7fe4c9020f4 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -187,8 +187,13 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() } ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { - return ExecutableRunOptions().set_allocator( - GetOrCreateAllocator(local_client_->platform())); + ExecutableRunOptions run_options; + run_options.set_inter_op_thread_pool( + local_client_->backend().inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + local_client_->backend().eigen_intra_op_thread_pool_device()); + run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); + return run_options; } std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1a151e20093..0cd0f97b062 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -53,6 +53,7 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { builder.ConstantR0(a), builder.ConstantR0(b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); + SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); LiteralUtil::EachCell(*actual, @@ -118,6 +119,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { builder.ConstantR0(range_size), ShapeUtil::MakeShape(S32, {sample_size})); + SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); LiteralUtil::EachCell( @@ -264,6 +266,7 @@ XLA_TEST_F(PrngTest, TenValuesN01) { builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), ShapeUtil::MakeShape(F32, {10})); + SetSeed(42); ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); // TODO(b/25995601): Test that resultant values are reasonable } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 7663dc76983..595d8997388 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -91,6 +91,7 @@ cc_library( "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:bucketization_op_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", + "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", ], ) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 2500f10c74d..35bab9abfbf 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -37,7 +37,7 @@ tf_kernel_library( srcs = [ "bigquery_reader_ops.cc", ], - visibility = ["//tensorflow:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":bigquery_table_accessor", ":bigquery_table_partition_proto_cc", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 60703e6997c..54691d2095d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -419,7 +419,7 @@ class DirichletMultinomialTest(test.TestCase): with self.test_session() as sess: dist = ds.DirichletMultinomial( total_count=5., - concentration=2. * self._rng.rand(4, 3, 2).astype(np.float32)) + concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32)) n = int(3e3) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) @@ -448,7 +448,7 @@ class DirichletMultinomialTest(test.TestCase): with self.test_session() as sess: dist = ds.DirichletMultinomial( total_count=5., - concentration=2. * self._rng.rand(4).astype(np.float32)) + concentration=1. + 2. * self._rng.rand(4).astype(np.float32)) n = int(5e3) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 9934b3d275f..aa9d45f151d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from scipy import stats from tensorflow.contrib import distributions +from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -50,6 +51,22 @@ class MultivariateNormalDiagTest(test.TestCase): dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) self.assertAllEqual([3, 1], dist.sample(3).get_shape()) + def testDistWithBatchShapeOneThenTransformedThroughSoftplus(self): + # This complex combination of events resulted in a loss of static shape + # information when tensor_util.constant_value(self._needs_rotation) was + # being used incorrectly (resulting in always rotating). + # Batch shape = [1], event shape = [3] + mu = array_ops.zeros((1, 3)) + diag = array_ops.ones((1, 3)) + with self.test_session(): + base_dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) + dist = ds.TransformedDistribution( + base_dist, + validate_args=True, + bijector=bijectors.Softplus(event_ndims=1)) + samps = dist.sample(5) # Shape [5, 1, 3]. + self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) + def testMean(self): mu = [-1., 1] diag = [1., -5] diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index 86ae8d4a521..844f78ca968 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -546,7 +546,8 @@ class TransformedDistribution(distributions.Distribution): def _maybe_rotate_dims(self, x, rotate_right=False): """Helper which rolls left event_dims left or right event_dims right.""" - if tensor_util.constant_value(self._needs_rotation) is False: + needs_rotation_const = tensor_util.constant_value(self._needs_rotation) + if needs_rotation_const is not None and not needs_rotation_const: return x ndims = array_ops.rank(x) n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 6ba8f7e8aec..8e2516dcd57 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -35,6 +35,7 @@ See the @{$python/contrib.layers} guide. @@relu6 @@repeat @@safe_embedding_lookup_sparse +@@scale_gradient @@separable_conv2d @@separable_convolution2d @@softmax @@ -68,6 +69,7 @@ See the @{$python/contrib.layers} guide. @@embedding_column @@scattered_embedding_column @@input_from_feature_columns +@@transform_features @@joint_weighted_sum_from_feature_columns @@make_place_holder_tensors_for_base_features @@multi_class_target diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index a04adaec315..282c556424e 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1248,21 +1248,29 @@ def scattered_embedding_column(column_name, initializer=None): """Creates an embedding column of a sparse feature using parameter hashing. - The i-th embedding component of a value v is found by retrieving an - embedding weight whose index is a fingerprint of the pair (v,i). + This is a useful shorthand when you have a sparse feature you want to use an + embedding for, but also want to hash the embedding's values in each dimension + to a variable based on a different hash. + + Specifically, the i-th embedding component of a value v is found by retrieving + an embedding weight whose index is a fingerprint of the pair (v,i). An embedding column with sparse_column_with_hash_bucket such as - embedding_column( + + embedding_column( sparse_column_with_hash_bucket(column_name, bucket_size), dimension) could be replaced by - scattered_embedding_column( - column_name, size=bucket_size * dimension, dimension=dimension, + + scattered_embedding_column( + column_name, + size=bucket_size * dimension, + dimension=dimension, hash_key=tf.contrib.layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY) - for the same number of embedding parameters and hopefully reduced impact of - collisions with a cost of slowing down training. + for the same number of embedding parameters. This should hopefully reduce the + impact of collisions, but adds the cost of slowing down training. Args: column_name: A string defining sparse column name. diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index de62a1778f3..7f1bfc9605b 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -144,6 +144,7 @@ def _input_from_feature_columns(columns_to_tensors, output_rank, default_name): """Implementation of `input_from(_sequence)_feature_columns`.""" + columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) with variable_scope.variable_scope(scope, default_name=default_name, @@ -430,6 +431,7 @@ def joint_weighted_sum_from_feature_columns(columns_to_tensors, ValueError: if FeatureColumn cannot be used for linear predictions. """ + columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) with variable_scope.variable_scope( scope, @@ -518,6 +520,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors, Raises: ValueError: if FeatureColumn cannot be used for linear predictions. """ + columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) with variable_scope.variable_scope( scope, @@ -684,8 +687,8 @@ def transform_features(features, feature_columns): Returns: A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values. """ - check_feature_columns(feature_columns) columns_to_tensor = features.copy() + check_feature_columns(feature_columns) transformer = _Transformer(columns_to_tensor) for column in sorted(set(feature_columns), key=lambda x: x.key): transformer.transform(column) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 6624f201c13..35daef9cc6e 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -187,27 +187,28 @@ class TransformerTest(test.TestCase): self.assertAllEqual(output.dense_shape.eval(), [2, 2]) def testEmbeddingColumn(self): - hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10) wire_tensor = sparse_tensor.SparseTensor( values=["omar", "stringer", "marlo"], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2]) features = {"wire": wire_tensor} - output = feature_column_ops._Transformer(features).transform( - feature_column.embedding_column(hashed_sparse, 10)) - expected = feature_column_ops._Transformer(features).transform( - hashed_sparse) - with self.test_session(): - self.assertAllEqual(output.values.eval(), expected.values.eval()) - self.assertAllEqual(output.indices.eval(), expected.indices.eval()) - self.assertAllEqual(output.dense_shape.eval(), - expected.dense_shape.eval()) + hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10) + wire_embedding = feature_column.embedding_column(hashed_sparse, 10) # Test transform features. output = feature_column_ops.transform_features( - features=features, feature_columns=[hashed_sparse]) - self.assertEqual(len(output), 1) + features=features, feature_columns=[hashed_sparse, wire_embedding]) + # Check that features dict haven't changed + self.assertEqual({"wire": wire_tensor}, features) + self.assertEqual(len(output), 2) self.assertIn(hashed_sparse, output) + self.assertIn(wire_embedding, output) + with self.test_session(): + self.assertAllEqual(output[wire_embedding].indices.eval(), + wire_tensor.indices.eval()) + self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2]) + self.assertAllEqual(output[wire_embedding].values.eval(), + output[hashed_sparse].values.eval()) def testSparseColumnWithKeys(self): keys_sparse = feature_column.sparse_column_with_keys( diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 07be8e9990f..65dcf8577f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.layers import convolutional as convolutional_layers @@ -68,6 +69,7 @@ __all__ = ['avg_pool2d', 'relu', 'relu6', 'repeat', + 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax', @@ -1745,6 +1747,48 @@ def repeat(inputs, repetitions, layer, *args, **kwargs): return outputs +def _scale_gradient_shape(op): + """Shape helper function for scale_gradient function below.""" + return [op.inputs[0].shape] + + +def _scale_gradient_grad(op, grad): + """Python gradient helper function for scale_gradient function below.""" + return [grad * op.inputs[1], None] + + +@function.Defun(python_grad_func=_scale_gradient_grad, + shape_func=_scale_gradient_shape) +def scale_gradient(inputs, gradient_multiplier): + """Identity operation, but with the gradient multiplied by a tensor. + + The TensorFlow gradient system will compute the gradient with respect to + `inputs` as the product of the gradient with respect to the `output` + multiplied by a specified `gradient_multiplier` tensor. If + `gradient_multiplier` is equal to 1, then this results in the true gradient. + Otherwise, it results in a scaled gradient. + + This can be useful for adjusting the relative learning rate of different + parameter tensors when performing gradient descent, and because this rescaling + can be inserted at arbitrary locations within a graph, is often more + convenient to apply than simply rescaling the final computed gradients. + + Args: + inputs: Tensor to be output. + gradient_multiplier: Tensor by which to multiply the gradient with respect + to `output` to compute the gradient with respect to `inputs`. Its shape + must be broadcastable to the shape of `inputs`. + + Returns: + output Tensor, equal to `inputs`. + """ + # gradient_multiplier is implicitly saved by decorator, and only used for + # gradient computation. + del gradient_multiplier + + return inputs + + @add_arg_scope def separable_convolution2d( inputs, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 8219f49dc59..3bc31a26249 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -2980,6 +2980,22 @@ class SeparableConv2dTest(test.TestCase): sess.run(net, feed_dict={images_placeholder: images}) +class ScaleGradientTests(test.TestCase): + """Simple tests of the scale_gradient function.""" + + def testBasic(self): + with self.test_session(): + x = np.array([42], np.float32) + gradient_scale = np.array([2], np.float32) + + x = ops.convert_to_tensor(x) + y = layers_lib.scale_gradient(x, gradient_scale) + + np.testing.assert_array_equal(x.eval(), y.eval()) + g_x, = gradients_impl.gradients(y, [x], [np.array([3], np.float32)]) + np.testing.assert_array_equal([3 * 2], g_x.eval()) + + class SoftmaxTests(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index b1189aba89b..37fe6faa312 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -165,8 +165,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): * embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to a `float` multiplier. Multiplier will be used to multiply with learning rate for the embedding variables. - * input_layer_min_slice_size: Optional. The min slice size of input layer - partitions. If not provided, will use the default of 64M. + * input_layer_partitioner: Optional. Partitioner for input layer. config: `RunConfig` object to configure the runtime settings. Returns: @@ -174,7 +173,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): Raises: ValueError: If both `linear_feature_columns` and `dnn_features_columns` - are empty at the same time. + are empty at the same time, or `input_layer_partitioner` is missing. """ head = params["head"] linear_feature_columns = params.get("linear_feature_columns") @@ -186,9 +185,11 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): dnn_activation_fn = params.get("dnn_activation_fn") or nn.relu dnn_dropout = params.get("dnn_dropout") gradient_clip_norm = params.get("gradient_clip_norm") - input_layer_min_slice_size = ( - params.get("input_layer_min_slice_size") or 64 << 20) num_ps_replicas = config.num_ps_replicas if config else 0 + input_layer_partitioner = params.get("input_layer_partitioner") or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) embedding_lr_multipliers = params.get("embedding_lr_multipliers", {}) fix_global_step_increment_bug = params.get( "fix_global_step_increment_bug", True) @@ -221,10 +222,6 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): dnn_parent_scope, values=tuple(six.itervalues(features)), partitioner=dnn_partitioner): - input_layer_partitioner = ( - partitioned_variables.min_max_variable_partitioner( - max_partitions=num_ps_replicas, - min_slice_size=input_layer_min_slice_size)) with variable_scope.variable_scope( "input_from_feature_columns", values=tuple(six.itervalues(features)), @@ -387,7 +384,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): config=None, feature_engineering_fn=None, embedding_lr_multipliers=None, - fix_global_step_increment_bug=False): + fix_global_step_increment_bug=False, + input_layer_partitioner=None): """Initializes a DNNLinearCombinedEstimator instance. Note: New users must set `fix_global_step_increment_bug=True` when creating @@ -432,6 +430,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator): steps to optimize both linear and dnn parts. If `True`, this bug is fixed. New users must set this to `True`, but the default value is `False` for backwards compatibility. + input_layer_partitioner: Optional. Partitioner for input layer. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -459,6 +458,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator): "gradient_clip_norm": gradient_clip_norm, "embedding_lr_multipliers": embedding_lr_multipliers, "fix_global_step_increment_bug": fix_global_step_increment_bug, + "input_layer_partitioner": input_layer_partitioner }, feature_engineering_fn=feature_engineering_fn) @@ -602,19 +602,25 @@ class DNNLinearCombinedClassifier(estimator.Estimator): ValueError: If both `linear_feature_columns` and `dnn_features_columns` are empty at the same time. """ - if n_classes < 2: - raise ValueError("n_classes should be greater than 1. Given: {}".format( - n_classes)) + head = head_lib.multi_class_head( + n_classes=n_classes, + weight_column_name=weight_column_name, + enable_centered_bias=enable_centered_bias) linear_feature_columns = tuple(linear_feature_columns or []) dnn_feature_columns = tuple(dnn_feature_columns or []) self._feature_columns = linear_feature_columns + dnn_feature_columns if not self._feature_columns: raise ValueError("Either linear_feature_columns or dnn_feature_columns " "must be defined.") - head = head_lib.multi_class_head( - n_classes=n_classes, - weight_column_name=weight_column_name, - enable_centered_bias=enable_centered_bias) + + # TODO(b/35922130): Replace with `input_layer_partitioner` arg. + input_layer_partitioner = None + if input_layer_min_slice_size is not None: + input_layer_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=config.num_ps_replicas if config else 0, + min_slice_size=input_layer_min_slice_size)) + super(DNNLinearCombinedClassifier, self).__init__( model_fn=_dnn_linear_combined_model_fn, model_dir=model_dir, @@ -631,7 +637,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator): "dnn_dropout": dnn_dropout, "gradient_clip_norm": gradient_clip_norm, "embedding_lr_multipliers": embedding_lr_multipliers, - "input_layer_min_slice_size": input_layer_min_slice_size, + "input_layer_partitioner": input_layer_partitioner, "fix_global_step_increment_bug": fix_global_step_increment_bug, }, feature_engineering_fn=feature_engineering_fn) @@ -916,6 +922,15 @@ class DNNLinearCombinedRegressor(estimator.Estimator): if not self._feature_columns: raise ValueError("Either linear_feature_columns or dnn_feature_columns " "must be defined.") + + # TODO(b/35922130): Replace with `input_layer_partitioner` arg. + input_layer_partitioner = None + if input_layer_min_slice_size is not None: + input_layer_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=config.num_ps_replicas if config else 0, + min_slice_size=input_layer_min_slice_size)) + head = head_lib.regression_head( weight_column_name=weight_column_name, label_dimension=label_dimension, @@ -936,7 +951,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator): "dnn_dropout": dnn_dropout, "gradient_clip_norm": gradient_clip_norm, "embedding_lr_multipliers": embedding_lr_multipliers, - "input_layer_min_slice_size": input_layer_min_slice_size, + "input_layer_partitioner": input_layer_partitioner, "fix_global_step_increment_bug": fix_global_step_increment_bug, }, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 13456c7f59c..301211ee822 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -341,7 +341,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase): input_layer_min_slice_size=1) # Ensure the param is passed in. - self.assertEqual(1, classifier.params['input_layer_min_slice_size']) + self.assertTrue(callable(classifier.params['input_layer_partitioner'])) # Ensure the partition count is 10. classifier.fit(input_fn=_input_fn_float_label, steps=50) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 628e106c42a..63ef058caa1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -188,6 +188,10 @@ def build_sequence_input(features, A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`. This will be used as input to an RNN. """ + features = features.copy() + features.update(layers.transform_features( + features, + list(sequence_feature_columns) + list(context_feature_columns or []))) sequence_input = layers.sequence_input_from_feature_columns( columns_to_tensors=features, feature_columns=sequence_feature_columns, diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a952969451..107454dca1a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -378,6 +378,8 @@ class BaseEstimator( self._model_dir = tempfile.mkdtemp() logging.warning('Using temporary folder as model directory: %s', self._model_dir) + if self._config.model_dir is None: + self._config = self._config.replace(model_dir=self._model_dir) # Set device function depending if there are replicas or not. self._device_fn = _get_replica_device_setter(self._config) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index de52abd6c31..1dc362beb89 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -481,6 +481,7 @@ class EstimatorTest(test.TestCase): est = estimator.Estimator(model_fn=linear_model_fn, config=config) self.assertEqual('test_dir', est.config.model_dir) + self.assertEqual('test_dir', est.model_dir) def testModelDirAndRunConfigModelDir(self): config = run_config.RunConfig(model_dir='test_dir') @@ -489,11 +490,30 @@ class EstimatorTest(test.TestCase): model_dir='test_dir') self.assertEqual('test_dir', est.config.model_dir) - with self.assertRaises(ValueError): + with self.assertRaisesRegexp( + ValueError, + 'model_dir are set both in constructor and RunConfig, ' + 'but with different'): estimator.Estimator(model_fn=linear_model_fn, config=config, model_dir='different_dir') + def testModelDirIsCopiedToRunConfig(self): + config = run_config.RunConfig() + self.assertIsNone(config.model_dir) + + est = estimator.Estimator(model_fn=linear_model_fn, + model_dir='test_dir', + config=config) + self.assertEqual('test_dir', est.config.model_dir) + self.assertEqual('test_dir', est.model_dir) + + def testModelDirAsTempDir(self): + with test.mock.patch.object(tempfile, 'mkdtemp', return_value='temp_dir'): + est = estimator.Estimator(model_fn=linear_model_fn) + self.assertEqual('temp_dir', est.config.model_dir) + self.assertEqual('temp_dir', est.model_dir) + def testCheckInputs(self): est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn)) # Lambdas so we have to different objects to compare diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index d7f1017a46a..d1b4aedb81e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -231,6 +231,8 @@ def sdca_model_fn(features, labels, mode, params): with variable_scope.variable_op_scope( features.values(), parent_scope) as scope: + features = features.copy() + features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( layers.weighted_sum_from_feature_columns( columns_to_tensors=features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 6049295bea8..bc7465bbc22 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -18,12 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy import json import os import six +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.core.protobuf import config_pb2 from tensorflow.python.training import server_lib @@ -291,7 +293,7 @@ class RunConfig(ClusterConfig): new_copy = copy.deepcopy(self) - # TODO(xiejw): Allow more fields, such as the user allowed changed ones. + # TODO(b/33295821): Allow more fields to be replaced. for key, new_value in six.iteritems(kwargs): if key == 'model_dir': new_copy._model_dir = new_value # pylint: disable=protected-access @@ -301,6 +303,24 @@ class RunConfig(ClusterConfig): return new_copy + @experimental + def uid(self): + """Generates a 'Unique Identifier' based on all internal fields. + + Caller should use the uid string to check `RunConfig` instance integrity + in one session use, but should not rely on the implementation details, which + is subject to change. + + Returns: + A uid string. + """ + # TODO(b/33295821): Allows user to specify a whitelist. + state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')} + ordered_state = collections.OrderedDict( + sorted(state.items(), key=lambda t: t[0])) + return ', '.join( + '%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state)) + @property def model_dir(self): return self._model_dir diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py index 842f71689f0..4d312ca8eea 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py @@ -238,6 +238,18 @@ class RunConfigTest(test.TestCase): with self.assertRaises(ValueError): config.replace(some_undefined_property=RANDOM_SEED) + def test_uid(self): + config = run_config.RunConfig( + tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) + + expected_uid = config.uid() + # Check for 10 times, which should prove something. + for _ in range(10): + self.assertEqual(expected_uid, config.uid()) + + new_config = config.replace(model_dir=ANOTHER_TEST_DIR) + self.assertNotEqual(expected_uid, new_config.uid()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index b2e48156917..4b7867f2d00 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -42,9 +42,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import coordinator -from tensorflow.python.training import monitored_session from tensorflow.python.training import queue_runner from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import session_manager as session_manager_lib @@ -121,205 +119,6 @@ def _run_with_monitors(session, step, tensors, feed_dict, monitors): return outputs, should_stop -def _monitored_train(graph, - output_dir, - train_op, - loss_op, - global_step_tensor=None, - init_op=None, - init_feed_dict=None, - init_fn=None, - log_every_steps=10, - supervisor_is_chief=True, - supervisor_master='', - supervisor_save_model_secs=600, - supervisor_save_model_steps=None, - keep_checkpoint_max=5, - keep_checkpoint_every_n_hours=10000.0, - supervisor_save_summaries_secs=None, - supervisor_save_summaries_steps=100, - feed_fn=None, - steps=None, - fail_on_nan_loss=True, - hooks=None, - max_steps=None): - """Train a model via monitored_session. - - Given `graph`, a directory to write outputs to (`output_dir`), and some ops, - run a training loop. The given `train_op` performs one step of training on the - model. The `loss_op` represents the objective function of the training. It is - expected to increment the `global_step_tensor`, a scalar integer tensor - counting training steps. This function uses `Supervisor` to initialize the - graph (from a checkpoint if one is available in `output_dir`), write summaries - defined in the graph, and write regular checkpoints as defined by - `supervisor_save_model_secs`. - - Training continues until `global_step_tensor` evaluates to `max_steps`, or, if - `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the - program is terminated with exit code 1. - - Args: - graph: A graph to train. It is expected that this graph is not in use - elsewhere. - output_dir: A directory to write outputs to. - train_op: An op that performs one training step when run. - loss_op: A scalar loss tensor. - global_step_tensor: A tensor representing the global step. If none is given, - one is extracted from the graph using the same logic as in `Supervisor`. - init_op: An op that initializes the graph. If `None`, use `Supervisor`'s - default. - init_feed_dict: A dictionary that maps `Tensor` objects to feed values. - This feed dictionary will be used when `init_op` is evaluated. - init_fn: Optional callable passed to Supervisor to initialize the model. - log_every_steps: Output logs regularly. The logs contain timing data and the - current loss. A `0` or negative value disables logging. - supervisor_is_chief: Whether the current process is the chief supervisor in - charge of restoring the model and running standard services. - supervisor_master: The master string to use when preparing the session. - supervisor_save_model_secs: Save checkpoints every this many seconds. Can - not be specified with `supervisor_save_model_steps`. - supervisor_save_model_steps: Save checkpoints every this many steps. Can not - be specified with `supervisor_save_model_secs`. - keep_checkpoint_max: The maximum number of recent checkpoint files to - keep. As new files are created, older files are deleted. If None or 0, - all checkpoint files are kept. This is simply passed as the max_to_keep - arg to `tf.train.Saver` constructor. - keep_checkpoint_every_n_hours: In addition to keeping the most recent - `keep_checkpoint_max` checkpoint files, you might want to keep one checkpoint file - for every N hours of training. This can be useful if you want to later - analyze how a model progressed during a long training session. For - example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep - one checkpoint file for every 2 hours of training. The default value of - 10,000 hours effectively disables the feature. - supervisor_save_summaries_secs: Save summaries every - `supervisor_save_summaries_secs` seconds when training. - supervisor_save_summaries_steps: Save summaries every - `supervisor_save_summaries_steps` steps when training. Exactly one of - `supervisor_save_model_steps` and `supervisor_save_model_secs` should be - specified, and the other should be None. - feed_fn: A function that is called every iteration to produce a `feed_dict` - passed to `session.run` calls. Optional. - steps: Trains for this many steps (e.g. current global step + `steps`). - fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op` - evaluates to `NaN`. If false, continue training as if nothing happened. - hooks: List of `SessionRunHook` subclass instances. Used for callbacks - inside the training loop. - max_steps: Number of total steps for which to train model. If `None`, - train forever. Two calls fit(steps=100) means 200 training iterations. - On the other hand two calls of fit(max_steps=100) means, second call - will not do any iteration since first call did all 100 steps. - - Returns: - The final loss value. - - Raises: - ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor` - is not provided. See `tf.contrib.framework.get_global_step` for how we - look up the latter if not provided explicitly. - NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever - evaluates to `NaN`. - ValueError: If both `steps` and `max_steps` are not `None`. - """ - if (steps is not None) and (max_steps is not None): - raise ValueError('Can not provide both steps and max_steps.') - if not output_dir: - raise ValueError('Output directory should be non-empty %s.' % output_dir) - if train_op is None: - raise ValueError('Missing train_op.') - if loss_op is None: - raise ValueError('Missing loss_op.') - if hooks is None: - hooks = [] - if not isinstance(hooks, list): - raise ValueError('Hooks should be a list.') - with graph.as_default(): - global_step_tensor = contrib_variables.assert_or_get_global_step( - graph, global_step_tensor) - if global_step_tensor is None: - raise ValueError('No "global_step" was provided or found in the graph.') - - if max_steps is not None: - try: - start_step = load_variable(output_dir, global_step_tensor.name) - if max_steps <= start_step: - logging.info('Skipping training since max_steps has already saved.') - return None - except: # pylint: disable=bare-except - pass - - # Adapted SessionRunHooks such as ExportMonitor depend on the - # CheckpointSaverHook to be executed before they should be executed. - # The `hooks` param comprises of deprecated monitor hooks - # (such as ExportMonitor). Appending them after the basic_session_run_hooks. - all_hooks = [] - with graph.as_default(): - all_hooks.append(basic_session_run_hooks.NanTensorHook( - loss_op, fail_on_nan_loss=fail_on_nan_loss)) - if log_every_steps > 0: - all_hooks.append(basic_session_run_hooks.LoggingTensorHook({ - 'loss': loss_op.name, - 'step': global_step_tensor.name - }, every_n_iter=log_every_steps)) - - def make_saver(): - return tf_saver.Saver( - sharded=True, - max_to_keep=keep_checkpoint_max, - keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, - defer_build=True) - - scaffold = monitored_session.Scaffold( - init_op=init_op, - init_feed_dict=init_feed_dict, - init_fn=init_fn, - saver=monitored_session.Scaffold.get_or_default('saver', - ops.GraphKeys.SAVERS, - make_saver)) - - if not supervisor_is_chief: - session_creator = monitored_session.WorkerSessionCreator( - scaffold=scaffold, - master=supervisor_master) - else: - session_creator = monitored_session.ChiefSessionCreator( - scaffold=scaffold, - checkpoint_dir=output_dir, - master=supervisor_master) - summary_writer = summary_io.SummaryWriterCache.get(output_dir) - all_hooks.append( - basic_session_run_hooks.StepCounterHook( - summary_writer=summary_writer)) - all_hooks.append( - basic_session_run_hooks.SummarySaverHook( - save_secs=supervisor_save_summaries_secs, - save_steps=supervisor_save_summaries_steps, - summary_writer=summary_writer, - scaffold=scaffold)) - if (supervisor_save_model_secs is not None - or supervisor_save_model_steps is not None): - all_hooks.append( - basic_session_run_hooks.CheckpointSaverHook( - output_dir, - save_secs=supervisor_save_model_secs, - save_steps=supervisor_save_model_steps, - scaffold=scaffold)) - - if steps is not None or max_steps is not None: - all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps)) - all_hooks.extend(hooks) - - with monitored_session.MonitoredSession( - session_creator=session_creator, - hooks=all_hooks) as super_sess: - loss = None - while not super_sess.should_stop(): - _, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else - None) - - summary_io.SummaryWriterCache.clear() - return loss - - @_graph_action_deprecation def train(graph, output_dir, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 2aeddfd84ab..0d039d593b7 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -27,7 +27,6 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor from tensorflow.python.framework import constant_op -from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.ops import control_flow_ops @@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary -from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -220,7 +218,7 @@ class GraphActionsTest(test.TestCase): self.assertTrue(request_stop.called) def test_run_feeds_iter_calls_resources_init(self): - with ops.Graph().as_default() as g: + with ops.Graph().as_default(): in0, _, _ = self._build_inference_graph() handle = test_ops.stub_resource_handle_op(container='a', shared_name='b') resources.register_resource( @@ -314,7 +312,7 @@ class GraphActionsTest(test.TestCase): with ops.Graph().as_default() as g, self.test_session(g): variables_lib.create_global_step() v = variables.Variable(1.0) - w = variables.Variable( + variables.Variable( v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False) ready_for_local_init_op = variables.report_uninitialized_variables( variables.global_variables()) @@ -396,223 +394,11 @@ class GraphActionsTest(test.TestCase): }}, expected_session_logs=[]) - def test_train_invalid_args(self): - with ops.Graph().as_default() as g, self.test_session(g): - train_op = constant_op.constant(1.0) - loss_op = constant_op.constant(2.0) - with self.assertRaisesRegexp(ValueError, 'utput directory'): - learn.graph_actions._monitored_train( - g, # pylint: disable=protected-access - output_dir=None, - train_op=train_op, - loss_op=loss_op) - with self.assertRaisesRegexp(ValueError, 'utput directory'): - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir='', - train_op=constant_op.constant(1.0), - loss_op=constant_op.constant(2.0)) - with self.assertRaisesRegexp(ValueError, 'train_op'): - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=None, - loss_op=loss_op) - with self.assertRaisesRegexp(ValueError, 'loss_op'): - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=constant_op.constant(1.0), - loss_op=None) - with self.assertRaisesRegexp(ValueError, 'global_step'): - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=constant_op.constant(1.0), - loss_op=loss_op) - # TODO(ptucker): Resume training from previous ckpt. # TODO(ptucker): !supervisor_is_chief # TODO(ptucker): Custom init op for training. # TODO(ptucker): Mock supervisor, and assert all interactions. - def test_train(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - writer = learn.graph_actions.get_summary_writer(self._output_dir) - self._assert_summaries(self._output_dir, writer) - self._assert_ckpt(self._output_dir, False) - loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - steps=1) - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=g.as_graph_def(add_shapes=True), - saver_def=monitored_session.Scaffold().finalize().saver.saver_def) - self.assertEqual(2.0, loss) - self._assert_summaries( - self._output_dir, - writer, - expected_graphs=[g], - expected_meta_graphs=[meta_graph_def]) - self._assert_ckpt(self._output_dir, True) - - def test_train_steps_is_incremental(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - steps=10) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(10, step) - - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - steps=15) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(25, step) - - def test_train_max_steps_is_not_incremental(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - max_steps=10) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(10, step) - - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - max_steps=15) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(15, step) - - def test_train_skip_train_if_max_step_already_saved(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - max_steps=10) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(10, step) - - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - max_steps=10) - step = checkpoint_utils.load_variable( - self._output_dir, variables_lib.get_global_step().name) - self.assertEqual(10, step) - - def test_train_loss(self): - with ops.Graph().as_default() as g, self.test_session(g): - variables_lib.create_global_step() - loss_var = variables_lib.local_variable(10.0) - train_op = control_flow_ops.group( - state_ops.assign_add(variables_lib.get_global_step(), 1), - state_ops.assign_add(loss_var, -1.0)) - writer = learn.graph_actions.get_summary_writer(self._output_dir) - self._assert_summaries(self._output_dir, writer) - self._assert_ckpt(self._output_dir, False) - loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=loss_var.value(), - steps=6) - self.assertEqual(4.0, loss) - self._assert_summaries( - self._output_dir, - writer, - expected_graphs=[g], - expected_meta_graphs=None) - self._assert_ckpt(self._output_dir, True) - - def test_train_summaries(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - loss_op = constant_op.constant(2.0) - summary.scalar('loss', loss_op) - writer = learn.graph_actions.get_summary_writer(self._output_dir) - self._assert_summaries(self._output_dir, writer) - self._assert_ckpt(self._output_dir, False) - loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=loss_op, - steps=1) - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=g.as_graph_def(add_shapes=True), - saver_def=monitored_session.Scaffold().finalize().saver.saver_def) - self.assertEqual(2.0, loss) - self._assert_summaries( - self._output_dir, - writer, - expected_graphs=[g], - expected_meta_graphs=[meta_graph_def], - expected_summaries={1: { - 'loss': 2.0 - }}) - self._assert_ckpt(self._output_dir, True) - - def test_train_override_saver(self): - with ops.Graph().as_default() as g, self.test_session(g): - with ops.control_dependencies(self._build_inference_graph()): - train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) - self._assert_ckpt(self._output_dir, False) - real_saver = saver_lib.Saver() - saver = test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def) - ops.add_to_collection(ops.GraphKeys.SAVERS, saver) - loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access - g, - output_dir=self._output_dir, - train_op=train_op, - loss_op=constant_op.constant(2.0), - steps=1) - self.assertEqual(2.0, loss) - self._assert_ckpt(self._output_dir, True) - self.assertTrue(saver.build.called) - self.assertEqual(1, saver.save.call_count) - # TODO(ispir): remove following tests after deprecated train. class GraphActionsTrainTest(test.TestCase): diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py index 85c32efb833..9f5f2856f13 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py @@ -60,14 +60,6 @@ class LinearOperatorTriLTest( return operator, mat, feed_dict - def test_assert_positive_definite(self): - # Singlular matrix with one positive eigenvalue and one negative eigenvalue. - with self.test_session(): - tril = [[1., 0.], [1., -1.]] - operator = linalg.LinearOperatorTriL(tril) - with self.assertRaisesOpError("was not positive definite"): - operator.assert_positive_definite().run() - def test_assert_non_singular(self): # Singlular matrix with one positive eigenvalue and one zero eigenvalue. with self.test_session(): diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py index bf6f8f83027..a06af336e71 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py @@ -21,8 +21,10 @@ import numpy as np from tensorflow.contrib import linalg as linalg_lib from tensorflow.contrib.linalg.python.ops import linear_operator_util +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -91,6 +93,142 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase): z, message="ABC123").run() +class BroadcastMatrixBatchDimsTest(test.TestCase): + + def test_zero_batch_matrices_returned_as_empty_list(self): + self.assertAllEqual( + [], linear_operator_util.broadcast_matrix_batch_dims([])) + + def test_one_batch_matrix_returned_after_tensor_conversion(self): + arr = rng.rand(2, 3, 4) + tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr]) + self.assertTrue(isinstance(tensor, ops.Tensor)) + + with self.test_session(): + self.assertAllClose(arr, tensor.eval()) + + def test_static_dims_broadcast(self): + # x.batch_shape = [3, 1, 2] + # y.batch_shape = [4, 1] + # broadcast batch shape = [3, 4, 2] + x = rng.rand(3, 1, 2, 1, 5) + y = rng.rand(4, 1, 3, 7) + batch_of_zeros = np.zeros((3, 4, 2, 1, 1)) + x_bc_expected = x + batch_of_zeros + y_bc_expected = y + batch_of_zeros + + x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y]) + + with self.test_session() as sess: + self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape()) + self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape()) + x_bc_, y_bc_ = sess.run([x_bc, y_bc]) + self.assertAllClose(x_bc_expected, x_bc_) + self.assertAllClose(y_bc_expected, y_bc_) + + def test_static_dims_broadcast_second_arg_higher_rank(self): + # x.batch_shape = [1, 2] + # y.batch_shape = [1, 3, 1] + # broadcast batch shape = [1, 3, 2] + x = rng.rand(1, 2, 1, 5) + y = rng.rand(1, 3, 2, 3, 7) + batch_of_zeros = np.zeros((1, 3, 2, 1, 1)) + x_bc_expected = x + batch_of_zeros + y_bc_expected = y + batch_of_zeros + + x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y]) + + with self.test_session() as sess: + self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape()) + self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape()) + x_bc_, y_bc_ = sess.run([x_bc, y_bc]) + self.assertAllClose(x_bc_expected, x_bc_) + self.assertAllClose(y_bc_expected, y_bc_) + + def test_dynamic_dims_broadcast_32bit(self): + # x.batch_shape = [3, 1, 2] + # y.batch_shape = [4, 1] + # broadcast batch shape = [3, 4, 2] + x = rng.rand(3, 1, 2, 1, 5).astype(np.float32) + y = rng.rand(4, 1, 3, 7).astype(np.float32) + batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32) + x_bc_expected = x + batch_of_zeros + y_bc_expected = y + batch_of_zeros + + x_ph = array_ops.placeholder(dtypes.float32) + y_ph = array_ops.placeholder(dtypes.float32) + + x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph]) + + with self.test_session() as sess: + x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y}) + self.assertAllClose(x_bc_expected, x_bc_) + self.assertAllClose(y_bc_expected, y_bc_) + + def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self): + # x.batch_shape = [1, 2] + # y.batch_shape = [3, 4, 1] + # broadcast batch shape = [3, 4, 2] + x = rng.rand(1, 2, 1, 5).astype(np.float32) + y = rng.rand(3, 4, 1, 3, 7).astype(np.float32) + batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32) + x_bc_expected = x + batch_of_zeros + y_bc_expected = y + batch_of_zeros + + x_ph = array_ops.placeholder(dtypes.float32) + y_ph = array_ops.placeholder(dtypes.float32) + + x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph]) + + with self.test_session() as sess: + x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y}) + self.assertAllClose(x_bc_expected, x_bc_) + self.assertAllClose(y_bc_expected, y_bc_) + + def test_less_than_two_dims_raises_static(self): + x = rng.rand(3) + y = rng.rand(1, 1) + + with self.assertRaisesRegexp(ValueError, "at least two dimensions"): + linear_operator_util.broadcast_matrix_batch_dims([x, y]) + + with self.assertRaisesRegexp(ValueError, "at least two dimensions"): + linear_operator_util.broadcast_matrix_batch_dims([y, x]) + + +class MatmulWithBroadcastTest(test.TestCase): + + def test_static_dims_broadcast(self): + # batch_shape = [2] + # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7 + x = rng.rand(2, 1, 3) + y = rng.rand(3, 7) + y_broadcast = y + np.zeros((2, 1, 1)) + + with self.test_session(): + result = linear_operator_util.matmul_with_broadcast(x, y) + self.assertAllEqual((2, 1, 7), result.get_shape()) + expected = math_ops.matmul(x, y_broadcast) + self.assertAllEqual(expected.eval(), result.eval()) + + def test_dynamic_dims_broadcast_32bit(self): + # batch_shape = [2] + # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7 + x = rng.rand(2, 1, 3) + y = rng.rand(3, 7) + y_broadcast = y + np.zeros((2, 1, 1)) + + x_ph = array_ops.placeholder(dtypes.float64) + y_ph = array_ops.placeholder(dtypes.float64) + + with self.test_session() as sess: + result, expected = sess.run( + [linear_operator_util.matmul_with_broadcast(x_ph, y_ph), + math_ops.matmul(x, y_broadcast)], + feed_dict={x_ph: x, y_ph: y}) + self.assertAllEqual(expected, result) + + class DomainDimensionStubOperator(object): def __init__(self, domain_dimension): diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py index 27996480cf2..5052a0b15cf 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py @@ -155,8 +155,9 @@ class LinearOperator(object): is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `dtype` is real, this is equivalent to being symmetric. is_positive_definite: Expect that this operator is positive definite, - meaning the real part of all eigenvalues is positive. We do not require - the operator to be self-adjoint to be positive-definite. See: + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: https://en.wikipedia.org/wiki/Positive-definite_matrix\ #Extension_for_non_symmetric_matrices is_square: Expect that this operator acts like square [batch] matrices. @@ -461,8 +462,9 @@ class LinearOperator(object): def assert_positive_definite(self, name="assert_positive_definite"): """Returns an `Op` that asserts this operator is positive definite. - Here, positive definite means the real part of all eigenvalues is positive. - We do not require the operator to be self-adjoint. + Here, positive definite means that the quadratic form `x^H A x` has positive + real part for all nonzero `x`. Note that we do not require the operator to + be self-adjoint to be positive definite. Args: name: A name to give this `Op`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py index 11ce5e0e64b..9f3a4d230f7 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py @@ -113,7 +113,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator): is_self_adjoint=None, is_positive_definite=None, name=None): - """Initialize a `LinearOperatorComposition`. + r"""Initialize a `LinearOperatorComposition`. `LinearOperatorComposition` is initialized with a list of operators `[op_1,...,op_J]`. For the `apply` method to be well defined, the @@ -127,9 +127,10 @@ class LinearOperatorComposition(linear_operator.LinearOperator): is_self_adjoint: Expect that this operator is equal to its hermitian transpose. is_positive_definite: Expect that this operator is positive definite, - meaning the real part of all eigenvalues is positive. We do not require - the operator to be self-adjoint to be positive-definite. See: - https://en.wikipedia.org/wiki/Positive-definite_matrix + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ #Extension_for_non_symmetric_matrices name: A name for this `LinearOperator`. Default is the individual operators names joined with `_o_`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py index c09d2367622..0cd7e72a8b6 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py @@ -114,7 +114,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): is_self_adjoint=None, is_positive_definite=None, name="LinearOperatorDiag"): - """Initialize a `LinearOperatorDiag`. + r"""Initialize a `LinearOperatorDiag`. Args: diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. @@ -124,9 +124,10 @@ class LinearOperatorDiag(linear_operator.LinearOperator): is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `diag.dtype` is real, this is auto-set to `True`. is_positive_definite: Expect that this operator is positive definite, - meaning the real part of all eigenvalues is positive. We do not require - the operator to be self-adjoint to be positive-definite. See: - https://en.wikipedia.org/wiki/Positive-definite_matrix + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ #Extension_for_non_symmetric_matrices name: A name for this `LinearOperator`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py index f8d0202baf9..f9349682215 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py @@ -109,7 +109,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): is_self_adjoint=None, is_positive_definite=None, name="LinearOperatorFullMatrix"): - """Initialize a `LinearOperatorFullMatrix`. + r"""Initialize a `LinearOperatorFullMatrix`. Args: matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`. @@ -118,9 +118,10 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): is_self_adjoint: Expect that this operator is equal to its hermitian transpose. is_positive_definite: Expect that this operator is positive definite, - meaning the real part of all eigenvalues is positive. We do not require - the operator to be self-adjoint to be positive-definite. See: - https://en.wikipedia.org/wiki/Positive-definite_matrix + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ #Extension_for_non_symmetric_matrices name: A name for this `LinearOperator`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py index 6523bd29c32..60d8b2cdc03 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py @@ -200,7 +200,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): is_positive_definite=True, assert_proper_shapes=False, name="LinearOperatorIdentity"): - """Initialize a `LinearOperatorIdentity`. + r"""Initialize a `LinearOperatorIdentity`. The `LinearOperatorIdentity` is initialized with arguments defining `dtype` and shape. @@ -218,7 +218,12 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. - is_positive_definite: Expect that this operator is positive definite. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices assert_proper_shapes: Python `bool`. If `False`, only perform static checks that initialization and method arguments have proper shape. If `True`, and static checks are inconclusive, add asserts to the graph. @@ -523,7 +528,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): is_positive_definite=None, assert_proper_shapes=False, name="LinearOperatorScaledIdentity"): - """Initialize a `LinearOperatorScaledIdentity`. + r"""Initialize a `LinearOperatorScaledIdentity`. The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which determines the size of each identity matrix, and a `multiplier`, @@ -538,7 +543,12 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. - is_positive_definite: Expect that this operator is positive definite. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices assert_proper_shapes: Python `bool`. If `False`, only perform static checks that initialization and method arguments have proper shape. If `True`, and static checks are inconclusive, add asserts to the graph. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py index d72ec542e28..38461ce8a22 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py @@ -23,7 +23,6 @@ from tensorflow.contrib.linalg.python.ops import linear_operator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -108,7 +107,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator): is_self_adjoint=None, is_positive_definite=None, name="LinearOperatorTriL"): - """Initialize a `LinearOperatorTriL`. + r"""Initialize a `LinearOperatorTriL`. Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. @@ -122,9 +121,10 @@ class LinearOperatorTriL(linear_operator.LinearOperator): real-valued diagonal entries. In this case it is advised to use `LinearOperatorDiag`. is_positive_definite: Expect that this operator is positive definite, - meaning the real part of all eigenvalues is positive. We do not require - the operator to be self-adjoint to be positive-definite. See: - https://en.wikipedia.org/wiki/Positive-definite_matrix + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ #Extension_for_non_symmetric_matrices name: A name for this `LinearOperator`. @@ -173,20 +173,6 @@ class LinearOperatorTriL(linear_operator.LinearOperator): self._diag, message="Singular operator: Diagonal contained zero values.") - def _assert_positive_definite(self): - if self.dtype.is_complex: - message = ( - "Diagonal operator had diagonal entries with non-positive real part, " - "thus was not positive definite.") - else: - message = ( - "Real diagonal operator had non-positive diagonal entries, " - "thus was not positive definite.") - - return check_ops.assert_positive( - math_ops.real(self._diag), - message=message) - def _apply(self, x, adjoint=False): return math_ops.matmul(self._tril, x, adjoint_a=adjoint) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py index 7c7776e624b..89b5c1ab1b9 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py @@ -170,6 +170,8 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator): Default is `None`, unless `base_operator` is positive-definite `v = None` (meaning `u=v`), and `is_diag_update_positive`, in which case this defaults to `True`. + Note that we say an operator is positive definite when the quadratic + form `x^H A x` has positive real part for all nonzero `x`. is_square: Expect that this operator acts like square [batch] matrices. name: A name for this `LinearOperator`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py index 6e56fac2e3d..a52a235677f 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py @@ -87,13 +87,208 @@ def assert_compatible_matrix_dimensions(operator, x): assert_same_dd = check_ops.assert_equal( array_ops.shape(x)[-2], operator.domain_dimension_tensor(), - message=( - "Incompatible matrix dimensions. " - "shape[-2] of argument to be the same as this operator")) + message=("Incompatible matrix dimensions. " + "shape[-2] of argument to be the same as this operator")) return assert_same_dd +def assert_is_batch_matrix(tensor): + """Static assert that `tensor` has rank `2` or higher.""" + sh = tensor.get_shape() + if sh.ndims is not None and sh.ndims < 2: + raise ValueError( + "Expected [batch] matrix to have at least two dimensions. Found: " + "%s" % tensor) + + +def broadcast_matrix_batch_dims(batch_matrices, name=None): + """Broadcast leading dimensions of zero or more [batch] matrices. + + Example broadcasting one batch dim of two simple matrices. + + ```python + x = [[1, 2], + [3, 4]] # Shape [2, 2], no batch dims + + y = [[[1]]] # Shape [1, 1, 1], 1 batch dim of shape [1] + + x_bc, y_bc = broadcast_matrix_batch_dims([x, y]) + + x_bc + ==> [[[1, 2], + [3, 4]]] # Shape [1, 2, 2], 1 batch dim of shape [1]. + + y_bc + ==> same as y + ``` + + Example broadcasting many batch dims + + ```python + x = tf.random_normal(shape=(2, 3, 1, 4, 4)) + y = tf.random_normal(shape=(1, 3, 2, 5, 5)) + x_bc, y_bc = broadcast_matrix_batch_dims([x, y]) + + x_bc.shape + ==> (2, 3, 2, 4, 4) + + y_bc.shape + ==> (2, 3, 2, 5, 5) + ``` + + Args: + batch_matrices: Iterable of `Tensor`s, each having two or more dimensions. + name: A string name to prepend to created ops. + + Returns: + bcast_matrices: List of `Tensor`s, with `bcast_matricies[i]` containing + the values from `batch_matrices[i]`, with possibly broadcast batch dims. + + Raises: + ValueError: If any input `Tensor` is statically determined to have less + than two dimensions. + """ + with ops.name_scope( + name or "broadcast_matrix_batch_dims", values=batch_matrices): + check_ops.assert_proper_iterable(batch_matrices) + batch_matrices = list(batch_matrices) + + for i, mat in enumerate(batch_matrices): + batch_matrices[i] = ops.convert_to_tensor(mat) + assert_is_batch_matrix(batch_matrices[i]) + + if len(batch_matrices) < 2: + return batch_matrices + + # Try static broadcasting. + # bcast_batch_shape is the broadcast batch shape of ALL matrices. + # E.g. if batch_matrices = [x, y], with + # x.shape = [2, j, k] (batch shape = [2]) + # y.shape = [3, 1, l, m] (batch shape = [3, 1]) + # ==> bcast_batch_shape = [3, 2] + bcast_batch_shape = batch_matrices[0].get_shape()[:-2] + for mat in batch_matrices[1:]: + bcast_batch_shape = array_ops.broadcast_static_shape( + bcast_batch_shape, mat.get_shape()[:-2]) + if bcast_batch_shape.is_fully_defined(): + # The [1, 1] at the end will broadcast with anything. + bcast_shape = bcast_batch_shape.concatenate([1, 1]) + for i, mat in enumerate(batch_matrices): + if mat.get_shape()[:-2] != bcast_batch_shape: + batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape) + return batch_matrices + + # Since static didn't work, do dynamic, which always copies data. + bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2] + for mat in batch_matrices[1:]: + bcast_batch_shape = array_ops.broadcast_dynamic_shape( + bcast_batch_shape, array_ops.shape(mat)[:-2]) + bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0) + for i, mat in enumerate(batch_matrices): + batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape) + + return batch_matrices + + +def _broadcast_to_shape(x, shape): + return x + array_ops.zeros(shape=shape, dtype=x.dtype) + + +def matmul_with_broadcast(a, + b, + transpose_a=False, + transpose_b=False, + adjoint_a=False, + adjoint_b=False, + a_is_sparse=False, + b_is_sparse=False, + name=None): + """Multiplies matrix `a` by matrix `b`, producing `a @ b`. + + The inputs must be matrices (or tensors of rank > 2, representing batches of + matrices). + + Both matrices must be of the same type. The supported types are: + `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`. + + Either matrix can be transposed or adjointed (conjugated and transposed) on + the fly by setting one of the corresponding flag to `True`. These are `False` + by default. + + If one or both of the matrices contain a lot of zeros, a more efficient + multiplication algorithm can be used by setting the corresponding + `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default. + This optimization is only available for plain matrices (rank-2 tensors) with + datatypes `bfloat16` or `float32`. + + For example: + + ```python + # A 2-batch of 3x4 matrices + a = tf.random_normal(shape=(2, 3, 4)) + + # A single 4x5 matrix + b = tf.random_normal(shape=(4, 5)) + + result = matmul_with_broadcast(a, b) + + result.shape + ==> (2, 3, 5) + + result[0,...] + ==> tf.matmul(a[0,...], b) + + result[1,...] + ==> tf.matmul(a[1,...], b) + ``` + + Args: + a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`, + `complex128` and `rank > 1`. + b: `Tensor` with same type as `a` having compatible matrix dimensions and + broadcastable batch dimensions. + transpose_a: If `True`, `a` is transposed before multiplication. + transpose_b: If `True`, `b` is transposed before multiplication. + adjoint_a: If `True`, `a` is conjugated and transposed before + multiplication. + adjoint_b: If `True`, `b` is conjugated and transposed before + multiplication. + a_is_sparse: If `True`, `a` is treated as a sparse matrix. + b_is_sparse: If `True`, `b` is treated as a sparse matrix. + name: Name for the operation (optional). + + Returns: + A `Tensor` of the same type as `a` and `b` where each inner-most matrix is + the product of the corresponding matrices in `a` and `b`, e.g. if all + transpose or adjoint attributes are `False`: + + The leading shape of `output` is the result of broadcasting the leading + dimensions of `a` and `b`. + + `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]), + for all indices i, j. + + Note: This is matrix product, not element-wise product. + + + Raises: + ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b + are both set to True. + """ + with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name: + a, b = broadcast_matrix_batch_dims([a, b]) + return math_ops.matmul( + a, + b, + transpose_a=transpose_a, + transpose_b=transpose_b, + adjoint_a=adjoint_a, + adjoint_b=adjoint_b, + a_is_sparse=a_is_sparse, + b_is_sparse=b_is_sparse) + + def shape_tensor(shape, name=None): """Convert Tensor using default type, unless empty list or tuple.""" # Works just like random_ops._ShapeTensor. diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index d4c74a98df6..733b03eed36 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -270,6 +270,8 @@ def sdca_model_fn(features, labels, mode, params, config=None): with variable_scope.variable_op_scope(features.values(), parent_scope) as scope: + features = features.copy() + features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( layers.weighted_sum_from_feature_columns( columns_to_tensors=features, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index c1ba9d4eadf..90c1440f085 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1879,11 +1879,11 @@ def streaming_pearson_correlation(predictions, math_ops.multiply(math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), 'pearson_r') - with ops.control_dependencies( - [update_cov, update_var_predictions, update_var_labels]): - update_op = _safe_div(update_cov, math_ops.multiply( - math_ops.sqrt(update_var_predictions), - math_ops.sqrt(update_var_labels)), 'update_op') + update_op = _safe_div( + update_cov, + math_ops.multiply(math_ops.sqrt(update_var_predictions), + math_ops.sqrt(update_var_labels)), + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, pearson_r) diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 88d9c8469a6..ce1ed7f491b 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -42,6 +42,7 @@ See @{$python/contrib.rnn} guide. @@GridLSTMCell @@BidirectionalGridLSTMCell @@NASCell +@@PhasedLSTMCell ### RNNCell wrappers @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 431065ef0b3..d3af4de7211 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -173,7 +173,6 @@ class RNNCellTest(test.TestCase): with self.test_session() as sess: num_units = 8 batch_size = 3 - input_size = 4 feature_size = 2 frequency_skip = 1 num_frequency_blocks = [1, 1] @@ -844,6 +843,45 @@ class RNNCellTest(test.TestCase): "be set to num_units at cell init."): cell(inputs, init_state) + def testPhasedLSTMCell(self): + with self.test_session() as sess: + num_units = 2 + batch_size = 3 + input_size = 4 + expected_state_c = np.array( + [[2.954548e-01, 8.354891e-04], + [2.834632e-01, 8.158963e-01], + [2.291694e-01, 1.325745e-04]], + dtype=np.float32) + expected_state_h = np.array( + [[2.116566e-01, 5.985238e-04], + [2.137760e-01, 6.153145e-01], + [1.742966e-01, 1.008306e-04]], + dtype=np.float32) + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64) + x = array_ops.zeros([batch_size, input_size]) + c0 = array_ops.zeros([batch_size, 2]) + h0 = array_ops.zeros([batch_size, 2]) + state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0) + output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x), + state0) + sess.run([variables.global_variables_initializer()]) + res = sess.run([output, state], { + t.name: + np.array([[1.], [2.], [3.]]), + x.name: + np.array([[1., 1., 1., 1.], + [2., 2., 2., 2.], + [3., 3., 3., 3.]]), + }) + # This is a smoke test, making sure expected values are unchanged. + self.assertEqual(len(res), 2) + self.assertAllClose(res[0], res[1].h) + self.assertAllClose(res[1].c, expected_state_c) + self.assertAllClose(res[1].h, expected_state_h) + class LayerNormBasicLSTMCellTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index c447b52f660..2cd18142131 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -1683,3 +1684,178 @@ class CompiledWrapper(core_rnn_cell.RNNCell): with jit.experimental_jit_scope(compile_ops=compile_ops): return self._cell(inputs, state, scope=scope) + + +def _random_exp_initializer(minval, + maxval, + seed=None, + dtype=dtypes.float32): + """Returns an exponential distribution initializer. + + Args: + minval: float or a scalar float Tensor. With value > 0. Lower bound of the + range of random values to generate. + maxval: float or a scalar float Tensor. With value > minval. Upper bound of + the range of random values to generate. + seed: An integer. Used to create random seeds. + dtype: The data type. + + Returns: + An initializer that generates tensors with an exponential distribution. + """ + + def _initializer(shape, dtype=dtype, partition_info=None): + del partition_info # Unused. + return math_ops.exp( + random_ops.random_uniform( + shape, + math_ops.log(minval), + math_ops.log(maxval), + dtype, + seed=seed)) + + return _initializer + + +class PhasedLSTMCell(core_rnn_cell.RNNCell): + """Phased LSTM recurrent network cell. + + https://arxiv.org/pdf/1610.09513v1.pdf + """ + + def __init__(self, + num_units, + use_peepholes=False, + leak=0.001, + ratio_on=0.1, + trainable_ratio_on=True, + period_init_min=1.0, + period_init_max=1000.0, + reuse=None): + """Initialize the Phased LSTM cell. + + Args: + num_units: int, The number of units in the Phased LSTM cell. + use_peepholes: bool, set True to enable peephole connections. + leak: float or scalar float Tensor with value in [0, 1]. Leak applied + during training. + ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the + period during which the gates are open. + trainable_ratio_on: bool, weather ratio_on is trainable. + period_init_min: float or scalar float Tensor. With value > 0. + Minimum value of the initalized period. + The period values are initialized by drawing from the distribution: + e^U(log(period_init_min), log(period_init_max)) + Where U(.,.) is the uniform distribution. + period_init_max: float or scalar float Tensor. + With value > period_init_min. Maximum value of the initalized period. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + """ + self._num_units = num_units + self._use_peepholes = use_peepholes + self._leak = leak + self._ratio_on = ratio_on + self._trainable_ratio_on = trainable_ratio_on + self._period_init_min = period_init_min + self._period_init_max = period_init_max + self._reuse = reuse + + @property + def state_size(self): + return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units) + + @property + def output_size(self): + return self._num_units + + def _mod(self, x, y): + """Modulo function that propagates x gradients.""" + return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x + + def _get_cycle_ratio(self, time, phase, period): + """Compute the cycle ratio in the dtype of the time.""" + phase_casted = math_ops.cast(phase, dtype=time.dtype) + period_casted = math_ops.cast(period, dtype=time.dtype) + shifted_time = time - phase_casted + cycle_ratio = self._mod(shifted_time, period_casted) / period_casted + return math_ops.cast(cycle_ratio, dtype=dtypes.float32) + + def __call__(self, inputs, state, scope=None): + """Phased LSTM Cell. + + Args: + inputs: A tuple of 2 Tensor. + The first Tensor has shape [batch, 1], and type float32 or float64. + It stores the time. + The second Tensor has shape [batch, features_size], and type float32. + It stores the features. + state: core_rnn_cell.LSTMStateTuple, state from previous timestep. + scope: string, id of the variable scope. + + Returns: + A tuple containing: + - A Tensor of float32, and shape [batch_size, num_units], representing the + output of the cell. + - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape + [batch_size, num_units], representing the new state and the output. + """ + with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse): + (c_prev, h_prev) = state + (time, x) = inputs + + in_mask_gates = [x, h_prev] + if self._use_peepholes: + in_mask_gates.append(c_prev) + + with vs.variable_scope("mask_gates"): + mask_gates = math_ops.sigmoid( + _linear(in_mask_gates, 2 * self._num_units, True)) + [input_gate, forget_gate] = array_ops.split( + axis=1, num_or_size_splits=2, value=mask_gates) + + with vs.variable_scope("new_input"): + new_input = math_ops.tanh( + _linear([x, h_prev], self._num_units, True)) + + new_c = (c_prev * forget_gate + input_gate * new_input) + + in_out_gate = [x, h_prev] + if self._use_peepholes: + in_out_gate.append(new_c) + + with vs.variable_scope("output_gate"): + output_gate = math_ops.sigmoid( + _linear(in_out_gate, self._num_units, True)) + + new_h = math_ops.tanh(new_c) * output_gate + + period = vs.get_variable( + "period", [self._num_units], + initializer=_random_exp_initializer( + self._period_init_min, self._period_init_max)) + phase = vs.get_variable( + "phase", [self._num_units], + initializer=init_ops.random_uniform_initializer( + 0., period.initial_value)) + ratio_on = vs.get_variable( + "ratio_on", [self._num_units], + initializer=init_ops.constant_initializer(self._ratio_on), + trainable=self._trainable_ratio_on) + + cycle_ratio = self._get_cycle_ratio(time, phase, period) + + k_up = 2 * cycle_ratio / ratio_on + k_down = 2 - k_up + k_closed = self._leak * cycle_ratio + + k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) + k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) + + new_c = k * new_c + (1 - k) * c_prev + new_h = k * new_h + (1 - k) * h_prev + + new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) + + return new_h, new_state diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 2aacce9e9c0..9c3015ff250 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: enable=unused-import +import sys import functools import numpy as np @@ -38,6 +39,15 @@ from tensorflow.python.util import nest # pylint: enable=g-import-not-at-top +# for testing +AttentionWrapperState = wrapper.AttentionWrapperState # pylint: disable=invalid-name +LSTMStateTuple = core_rnn_cell.LSTMStateTuple # pylint: disable=invalid-name +BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid-name +float32 = np.float32 +int32 = np.int32 +array = np.array + + class AttentionWrapperTest(test.TestCase): def assertAllClose(self, *args, **kwargs): @@ -48,10 +58,11 @@ class AttentionWrapperTest(test.TestCase): def _testWithAttention(self, create_attention_mechanism, - expected_final_outputs, + expected_final_output, expected_final_state, attention_mechanism_depth=3, - attention_history=False): + attention_history=False, + name=""): encoder_sequence_length = [3, 2, 3, 1, 0] decoder_sequence_length = [2, 0, 1, 2, 3] batch_size = 5 @@ -126,7 +137,13 @@ class AttentionWrapperTest(test.TestCase): "state_attention_history": state_attention_history, }) - nest.map_structure(self.assertAllClose, expected_final_outputs, + print("Copy/paste (%s)\nexpected_final_output = " % name, + sess_results["final_outputs"]) + sys.stdout.flush() + print("Copy/paste (%s)\nexpected_final_state = " % name, + sess_results["final_state"]) + sys.stdout.flush() + nest.map_structure(self.assertAllClose, expected_final_output, sess_results["final_outputs"]) nest.map_structure(self.assertAllClose, expected_final_state, sess_results["final_state"]) @@ -137,533 +154,534 @@ class AttentionWrapperTest(test.TestCase): np.transpose(sess_results["final_outputs"].rnn_output, (1, 0, 2))) - def testBahndahauNotNormalized(self): + def testBahdanauNotNormalized(self): create_attention_mechanism = wrapper.BahdanauAttention - array = np.array - float32 = np.float32 - int32 = np.int32 - - expected_final_outputs = basic_decoder.BasicDecoderOutput( + expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.25166783e-02, -6.88887993e-03, 3.17239435e-03, - -1.98234897e-03, 4.77387803e-03, -1.38330357e-02 + 1.89980457e-03, 1.89681584e-03, 2.05339328e-03, -3.83376027e-03, + -4.31808922e-03, -6.45466987e-03 ], [ - 1.28883058e-02, -6.76271692e-03, 3.13419267e-03, - -2.02183682e-03, 5.62057737e-03, -1.35373026e-02 + 2.27232254e-03, 2.02509761e-03, 2.01666891e-03, -3.87230632e-03, + -3.47119337e-03, -6.15991233e-03 ], [ - 1.24917831e-02, -6.71574520e-03, 3.42238229e-03, - -1.79501204e-03, 5.33161033e-03, -1.36620644e-02 + 1.87640532e-03, 2.07374478e-03, 2.30582547e-03, -3.64564802e-03, + -3.75995948e-03, -6.28685066e-03 ]], [[ - 1.55150667e-02, -1.07274549e-02, 4.44198400e-03, - -9.73310322e-04, 1.27242506e-02, -1.21861566e-02 + 4.89835022e-03, -1.94158917e-03, 3.32316267e-03, + -2.82446202e-03, 3.63192149e-03, -4.80734091e-03 ], [ - 1.57585666e-02, -1.07965544e-02, 4.61554807e-03, - -1.01510016e-03, 1.22341057e-02, -1.27029382e-02 + 5.14256489e-03, -2.00877781e-03, 3.49807227e-03, + -2.86567654e-03, 3.14202951e-03, -5.32575324e-03 ], [ - 1.58304181e-02, -1.09712025e-02, 4.67861444e-03, - -1.03920139e-03, 1.23004699e-02, -1.25949886e-02 + 5.21511910e-03, -2.18198029e-03, 3.56219849e-03, + -2.88951304e-03, 3.20866983e-03, -5.21918852e-03 ]], [[ - 9.26700700e-03, -9.75431874e-03, -9.95740294e-04, - -1.27463136e-06, 3.81659716e-03, -1.64887272e-02 + -1.34951377e-03, -9.68646549e-04, -2.11444520e-03, + -1.85243192e-03, -5.27541339e-03, -9.10969637e-03 ], [ - 9.25191958e-03, -9.80092678e-03, -8.48566880e-04, - 5.02091134e-05, 3.46567202e-03, -1.67435352e-02 + -1.36390887e-03, -1.01293903e-03, -1.96592091e-03, + -1.80044665e-03, -5.62618347e-03, -9.36636236e-03 ], [ - 9.48173273e-03, -9.52653307e-03, -8.79382715e-04, - -3.07094306e-05, 4.05955408e-03, -1.67226996e-02 + -1.13357347e-03, -7.37126335e-04, -1.99582824e-03, + -1.88097963e-03, -5.03196474e-03, -9.34652984e-03 ]], [[ - 1.21462569e-02, -1.27578378e-02, 1.54045003e-04, 2.70257704e-03, - 7.79421115e-03, -8.14041123e-04 + 1.52963377e-03, -3.97205260e-03, -9.64675564e-04, + 8.51404853e-04, -1.29804458e-03, 6.56467676e-03 ], [ - 1.18412934e-02, -1.33513296e-02, 3.54760559e-05, 2.67801876e-03, - 6.99122995e-03, -9.46014654e-04 + 1.22557906e-03, -4.56343032e-03, -1.08188344e-03, + 8.27252632e-04, -2.10058759e-03, 6.43082103e-03 ], [ - 1.16087487e-02, -1.31632648e-02, -2.98853614e-04, - 2.49515846e-03, 6.92677684e-03, -6.92734495e-04 + 9.93478228e-04, -4.37378604e-03, -1.41531695e-03, + 6.44775166e-04, -2.16480484e-03, 6.68286439e-03 ]], [[ - 1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03, - 8.93574394e-03, -7.28237582e-03 + -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, + -1.56512906e-04, 9.63474595e-05 ], [ - 1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03, - 8.94328393e-03, -7.39541650e-03 + -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, + 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, + 2.73401442e-04, -2.69805576e-04 ]]], dtype=float32), sample_id=array( - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], dtype=int32)) - expected_final_state = wrapper.AttentionWrapperState( - time=3, - attention_history=(), - cell_state=core_rnn_cell.LSTMStateTuple( + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( c=array( [[ - -0.0220502, -0.008058, -0.00160266, 0.01609341, -0.01380513, - -0.00749483, -0.00816989, -0.01210028, 0.01795324 + -2.18963176e-02, -8.04424379e-03, -1.48289464e-03, + 1.61068402e-02, -1.37983467e-02, -7.57976994e-03, + -8.28560349e-03, -1.18737305e-02, 1.78835373e-02 ], [ - 0.01727026, -0.0142065, -0.00399991, 0.03195379, - -0.03547479, -0.02138772, -0.00610318, -0.00191625, - -0.01937846 + 1.74205080e-02, -1.41929444e-02, -3.88092734e-03, + 3.19708064e-02, -3.54689620e-02, -2.14698724e-02, + -6.21716119e-03, -1.69295724e-03, -1.94495302e-02 ], [ - -0.0116077, 0.00876439, -0.01641787, -0.01400803, - 0.01347527, -0.01036386, 0.00627491, -0.0096361, -0.00650565 + -1.14528481e-02, 8.77819210e-03, -1.62970200e-02, + -1.39963552e-02, 1.34831406e-02, -1.04494914e-02, + 6.16127765e-03, -9.41022579e-03, -6.57590060e-03 ], [ - -0.04763387, -0.01192631, -0.00019412, 0.04103886, - -0.00137999, 0.02126684, -0.02793711, -0.05467696, - -0.02912051 + -4.74753827e-02, -1.19123599e-02, -7.40140676e-05, + 4.10552323e-02, -1.36711076e-03, 2.11795457e-02, + -2.80460119e-02, -5.44509329e-02, -2.91906092e-02 ], [ - 0.02241185, -0.00141741, 0.01911988, 0.00547728, - -0.01280068, -0.00307024, -0.00494239, 0.02169247, - 0.01631995 + 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, + 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, + -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 ]], dtype=float32), h=array( [[ - -1.10613741e-02, -3.98175791e-03, -8.15514475e-04, - 7.90482666e-03, -7.02390168e-03, -3.76394135e-03, - -4.16183751e-03, -6.17114361e-03, 8.95532221e-03 + -1.09840557e-02, -3.97477299e-03, -7.54582870e-04, + 7.91188516e-03, -7.02184858e-03, -3.80711886e-03, + -4.22059745e-03, -6.05464494e-03, 8.92061181e-03 ], [ - 8.60657450e-03, -7.17655150e-03, -1.94156705e-03, - 1.62583217e-02, -1.76821016e-02, -1.06200138e-02, - -3.01904045e-03, -9.57608980e-04, -9.95732192e-03 + 8.68131686e-03, -7.16938032e-03, -1.88384682e-03, + 1.62678920e-02, -1.76827926e-02, -1.06622791e-02, + -3.07528162e-03, -8.45885137e-04, -9.99388192e-03 ], [ - -5.78935863e-03, 4.49362956e-03, -8.13615043e-03, - -6.95384294e-03, 6.75151078e-03, -5.07845683e-03, - 3.11869266e-03, -4.72904649e-03, -3.20469099e-03 + -5.71205560e-03, 4.50050412e-03, -8.07640795e-03, + -6.94844872e-03, 6.75682165e-03, -5.12113515e-03, + 3.06208082e-03, -4.61743120e-03, -3.23931244e-03 ], [ - -2.38025561e-02, -5.89242764e-03, -9.76260417e-05, - 2.01697368e-02, -6.82076614e-04, 1.07111251e-02, - -1.42077375e-02, -2.70790439e-02, -1.44685479e-02 + -2.37231534e-02, -5.88526297e-03, -3.72226204e-05, + 2.01789513e-02, -6.75848918e-04, 1.06686354e-02, + -1.42624676e-02, -2.69628745e-02, -1.45034352e-02 ], [ - 1.11825848e-02, -6.99267141e-04, 9.82748345e-03, - 2.74566701e-03, -6.56377291e-03, -1.53681310e-03, - -2.48806458e-03, 1.10462429e-02, 7.97568541e-03 + 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, + 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, + -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 ]], dtype=float32)), attention=array( [[ - 1.24917831e-02, -6.71574520e-03, 3.42238229e-03, - -1.79501204e-03, 5.33161033e-03, -1.36620644e-02 + 0.00187641, 0.00207374, 0.00230583, -0.00364565, -0.00375996, + -0.00628685 ], [ - 1.58304181e-02, -1.09712025e-02, 4.67861444e-03, - -1.03920139e-03, 1.23004699e-02, -1.25949886e-02 + 0.00521512, -0.00218198, 0.0035622, -0.00288951, 0.00320867, + -0.00521919 ], [ - 9.48173273e-03, -9.52653307e-03, -8.79382715e-04, - -3.07094306e-05, 4.05955408e-03, -1.67226996e-02 + -0.00113357, -0.00073713, -0.00199583, -0.00188098, -0.00503196, + -0.00934653 ], [ - 1.16087487e-02, -1.31632648e-02, -2.98853614e-04, - 2.49515846e-03, 6.92677684e-03, -6.92734495e-04 + 0.00099348, -0.00437379, -0.00141532, 0.00064478, -0.0021648, + 0.00668286 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, + -0.00026981 ]], - dtype=float32)) - self._testWithAttention(create_attention_mechanism, expected_final_outputs, - expected_final_state, attention_history=True) + dtype=float32), + time=3, + attention_history=()) - def testBahndahauNormalized(self): + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_history=True, + name="testBahdanauNotNormalized") + + def testBahdanauNormalized(self): create_attention_mechanism = functools.partial( - wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0) + wrapper.BahdanauAttention, normalize=True) - array = np.array - float32 = np.float32 - int32 = np.int32 - - expected_final_output = basic_decoder.BasicDecoderOutput( + expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.72670335e-02, -5.83671592e-03, 6.38638902e-03, - -8.11776379e-04, 1.12681929e-03, -1.24236047e-02 + 6.64783875e-03, 2.94425711e-03, 5.26542449e-03, -2.64955591e-03, + -7.95925129e-03, -5.02286293e-03 ], [ - 1.75918192e-02, -5.73426578e-03, 6.29768707e-03, - -8.63141613e-04, 2.03352375e-03, -1.21420780e-02 + 7.01954123e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03, + -7.11239874e-03, -4.72904276e-03 ], [ - 1.72424167e-02, -5.66471322e-03, 6.63427915e-03, - -6.23903936e-04, 1.68706616e-03, -1.22524602e-02 + 6.62360899e-03, 3.12234787e-03, 5.51807694e-03, -2.46222341e-03, + -7.40198931e-03, -4.85701021e-03 ]], [[ - 1.79958157e-02, -9.80986748e-03, 4.73218597e-03, - -3.89962713e-03, 1.41502675e-02, -1.48344040e-02 + 7.37589924e-03, -1.02620223e-03, 3.61374952e-03, + -5.74620720e-03, 5.05625410e-03, -7.45209027e-03 ], [ - 1.82184577e-02, -9.88379307e-03, 4.90130857e-03, - -3.91892251e-03, 1.36479288e-02, -1.53291579e-02 + 7.61946291e-03, -1.09287468e-03, 3.78817180e-03, + -5.78709645e-03, 4.56611114e-03, -7.96987582e-03 ], [ - 1.83001235e-02, -1.00617753e-02, 4.97077405e-03, - -3.94908339e-03, 1.37211196e-02, -1.52311027e-02 + 7.69207766e-03, -1.26582675e-03, 3.85218812e-03, + -5.81111759e-03, 4.63287206e-03, -7.86337163e-03 ]], [[ - 7.93476030e-03, -8.46967567e-03, -7.16930721e-04, - 4.37953044e-04, 1.04503892e-03, -1.82424393e-02 + -2.69413739e-03, 3.47183552e-04, -1.82145904e-03, + -1.39805069e-03, -8.05486552e-03, -1.08372131e-02 ], [ - 7.90629163e-03, -8.48819874e-03, -5.57833235e-04, - 5.02390554e-04, 6.79406337e-04, -1.84837580e-02 + -2.70848931e-03, 3.03293345e-04, -1.67230750e-03, + -1.34555507e-03, -8.40565283e-03, -1.10935047e-02 ], [ - 8.14734399e-03, -8.23053624e-03, -5.92814526e-04, - 4.16347990e-04, 1.29250437e-03, -1.84548404e-02 + -2.47822329e-03, 5.79408603e-04, -1.70188327e-03, + -1.42583530e-03, -7.81180616e-03, -1.10740755e-02 ]], [[ - 1.21026095e-02, -1.26739489e-02, 1.78718648e-04, 2.68748170e-03, - 7.80996867e-03, -9.69076063e-04 + 1.48582947e-03, -3.88786104e-03, -9.39912978e-04, + 8.36255029e-04, -1.28223014e-03, 6.40908210e-03 ], [ - 1.17978491e-02, -1.32678337e-02, 6.00410858e-05, 2.66301399e-03, - 7.00691342e-03, -1.10030361e-03 + 1.18177081e-03, -4.47923271e-03, -1.05711201e-03, + 8.12121783e-04, -2.08477327e-03, 6.27523474e-03 ], [ - 1.15651665e-02, -1.30795036e-02, -2.74205930e-04, - 2.48012133e-03, 6.94250735e-03, -8.47495161e-04 + 9.49664740e-04, -4.28957958e-03, -1.39053771e-03, + 6.29657647e-04, -2.14899099e-03, 6.52727811e-03 ]], [[ - 1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03, - 8.93574394e-03, -7.28237582e-03 + -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, + -1.56512906e-04, 9.63474595e-05 ], [ - 1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03, - 8.94328393e-03, -7.39541650e-03 + -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, + 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, + 2.73401442e-04, -2.69805576e-04 ]]], dtype=float32), sample_id=array( - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], dtype=int32)) - expected_final_state = wrapper.AttentionWrapperState( - time=3, - attention_history=(), - cell_state=core_rnn_cell.LSTMStateTuple( + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( c=array( [[ - -0.02209264, -0.00794879, -0.00157153, 0.01614309, - -0.01383773, -0.00750943, -0.00824213, -0.01210296, - 0.01794949 + -2.19389871e-02, -7.93421268e-03, -1.45148858e-03, + 1.61569901e-02, -1.38310911e-02, -7.59426132e-03, + -8.35836027e-03, -1.18763093e-02, 1.78797375e-02 ], [ - 0.01726926, -0.01418139, -0.0040099, 0.0319339, -0.03545783, - -0.02142831, -0.00609501, -0.00195033, -0.01938949 + 1.74194798e-02, -1.41677596e-02, -3.89095861e-03, + 3.19508761e-02, -3.54519747e-02, -2.15105712e-02, + -6.20894879e-03, -1.72719418e-03, -1.94605980e-02 ], [ - -0.01159083, 0.0087524, -0.01639001, -0.01400012, - 0.01342422, -0.01041037, 0.00620991, -0.00960796, - -0.00650131 + -1.14357909e-02, 8.76635592e-03, -1.62690803e-02, + -1.39883338e-02, 1.34323873e-02, -1.04959216e-02, + 6.09614328e-03, -9.38197412e-03, -6.57159975e-03 ], [ - -0.04763237, -0.01192762, -0.00019377, 0.04103839, - -0.00138058, 0.02126443, -0.02793917, -0.05467755, - -0.02912025 + -4.74738739e-02, -1.19136795e-02, -7.36564398e-05, + 4.10547666e-02, -1.36771239e-03, 2.11771261e-02, + -2.80481018e-02, -5.44515178e-02, -2.91903559e-02 ], [ - 0.02241185, -0.00141741, 0.01911988, 0.00547728, - -0.01280068, -0.00307024, -0.00494239, 0.02169247, - 0.01631995 + 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, + 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, + -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 ]], dtype=float32), h=array( [[ - -1.10821165e-02, -3.92766716e-03, -7.99638336e-04, - 7.92923011e-03, -7.04019284e-03, -3.77124036e-03, - -4.19876305e-03, -6.17261464e-03, 8.95325281e-03 + -1.10049099e-02, -3.92028037e-03, -7.38571223e-04, + 7.93652050e-03, -7.03821564e-03, -3.81436548e-03, + -4.25778655e-03, -6.05606195e-03, 8.91851448e-03 ], [ - 8.60597286e-03, -7.16368994e-03, -1.94644753e-03, - 1.62479617e-02, -1.76739115e-02, -1.06403306e-02, - -3.01484042e-03, -9.74688213e-04, -9.96260438e-03 + 8.68070032e-03, -7.15647917e-03, -1.88874488e-03, + 1.62575077e-02, -1.76745858e-02, -1.06826536e-02, + -3.07105901e-03, -8.63034453e-04, -9.99918394e-03 ], [ - -5.78098884e-03, 4.48751403e-03, -8.12216662e-03, - -6.94991415e-03, 6.72604749e-03, -5.10144979e-03, - 3.08637507e-03, -4.71517537e-03, -3.20256175e-03 + -5.70359221e-03, 4.49446775e-03, -8.06238409e-03, + -6.94446685e-03, 6.73149945e-03, -5.14409645e-03, + 3.02969781e-03, -4.60351165e-03, -3.23720207e-03 ], [ - -2.38018110e-02, -5.89307398e-03, -9.74484938e-05, - 2.01694984e-02, -6.82370039e-04, 1.07099237e-02, - -1.42087601e-02, -2.70793457e-02, -1.44684138e-02 + -2.37224046e-02, -5.88591257e-03, -3.70427515e-05, + 2.01787166e-02, -6.76146999e-04, 1.06674293e-02, + -1.42635051e-02, -2.69631781e-02, -1.45033030e-02 ], [ - 1.11825848e-02, -6.99267141e-04, 9.82748345e-03, - 2.74566701e-03, -6.56377291e-03, -1.53681310e-03, - -2.48806458e-03, 1.10462429e-02, 7.97568541e-03 + 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, + 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, + -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 ]], dtype=float32)), attention=array( [[ - 0.01724242, -0.00566471, 0.00663428, -0.0006239, 0.00168707, - -0.01225246 + 0.00662361, 0.00312235, 0.00551808, -0.00246222, -0.00740199, + -0.00485701 ], [ - 0.01830012, -0.01006178, 0.00497077, -0.00394908, 0.01372112, - -0.0152311 + 0.00769208, -0.00126583, 0.00385219, -0.00581112, 0.00463287, + -0.00786337 ], [ - 0.00814734, -0.00823054, -0.00059281, 0.00041635, 0.0012925, - -0.01845484 + -0.00247822, 0.00057941, -0.00170188, -0.00142584, -0.00781181, + -0.01107408 ], [ - 0.01156517, -0.0130795, -0.00027421, 0.00248012, 0.00694251, - -0.0008475 + 0.00094966, -0.00428958, -0.00139054, 0.00062966, -0.00214899, + 0.00652728 ], [ - 0.01073981, -0.00856867, 0.00152354, 0.00206834, 0.00936512, - -0.00764556 + 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, + -0.00026981 ]], - dtype=float32)) + dtype=float32), + time=3, + attention_history=()) - self._testWithAttention(create_attention_mechanism, expected_final_output, - expected_final_state) + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + name="testBahdanauNormalized") def testLuongNotNormalized(self): create_attention_mechanism = wrapper.LuongAttention - array = np.array - float32 = np.float32 - int32 = np.int32 - - expected_final_output = basic_decoder.BasicDecoderOutput( + expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.23641128e-02, -6.82715839e-03, 3.24165262e-03, - -1.90772023e-03, 4.69654519e-03, -1.37025211e-02 + 1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03, + -4.39571124e-03, -6.32379763e-03 ], [ - 1.29463980e-02, -6.79699238e-03, 3.10124992e-03, - -2.02869414e-03, 5.66399656e-03, -1.35517996e-02 + 2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03, + -3.42792575e-03, -6.17497414e-03 ], [ - 1.22659411e-02, -6.81970268e-03, 3.15135531e-03, - -1.96937821e-03, 5.62768336e-03, -1.39173865e-02 + 1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03, + -3.46369296e-03, -6.54224353e-03 ]], [[ - 1.53944232e-02, -1.07725551e-02, 4.42822604e-03, - -8.30623554e-04, 1.26549732e-02, -1.20573286e-02 + 4.77780215e-03, -1.98677275e-03, 3.30950436e-03, + -2.68179504e-03, 3.56271653e-03, -4.67860466e-03 ], [ - 1.57453734e-02, -1.08157266e-02, 4.62466478e-03, - -9.88351414e-04, 1.22286947e-02, -1.26876952e-02 + 5.13039157e-03, -2.02797214e-03, 3.50760575e-03, + -2.83981953e-03, 3.13726603e-03, -5.31156827e-03 ], [ - 1.57857724e-02, -1.09536834e-02, 4.64798324e-03, - -1.01319887e-03, 1.22695938e-02, -1.25500849e-02 + 5.17205056e-03, -2.16446724e-03, 3.53219034e-03, + -2.86490913e-03, 3.17879021e-03, -5.17592067e-03 ]], [[ - 9.23123397e-03, -9.42669343e-03, -9.09919385e-04, - 6.09827694e-05, 3.90436035e-03, -1.63374804e-02 + -1.38538703e-03, -6.40910701e-04, -2.02864106e-03, + -1.79018872e-03, -5.18789608e-03, -8.95875692e-03 ], [ - 9.22935922e-03, -9.57853813e-03, -7.92966573e-04, - 8.89014918e-05, 3.52671882e-03, -1.66499857e-02 + -1.38620089e-03, -7.92010222e-04, -1.91070826e-03, + -1.76206254e-03, -5.56525169e-03, -9.27332044e-03 ], [ - 9.49526206e-03, -9.39475093e-03, -8.49372707e-04, - -1.72815053e-05, 4.16132808e-03, -1.66336838e-02 + -1.11966045e-03, -6.07630936e-04, -1.96643686e-03, + -1.86803937e-03, -4.93048411e-03, -9.25842486e-03 ]], [[ - 1.21248290e-02, -1.27166547e-02, 1.66158192e-04, 2.69516627e-03, - 7.80194718e-03, -8.90152063e-04 + 1.50820788e-03, -3.93087184e-03, -9.52563598e-04, + 8.43994785e-04, -1.29030924e-03, 6.48857141e-03 ], [ - 1.17861275e-02, -1.32453050e-02, 6.66640699e-05, 2.65894993e-03, - 7.01114535e-03, -1.14195189e-03 + 1.17029145e-03, -4.45716921e-03, -1.05062663e-03, + 8.08141369e-04, -2.08062865e-03, 6.23444980e-03 ], [ - 1.15833860e-02, -1.31145213e-02, -2.84505659e-04, - 2.48642010e-03, 6.93593081e-03, -7.82784075e-04 + 9.67921398e-04, -4.32466762e-03, -1.40085898e-03, + 6.35969569e-04, -2.15558149e-03, 6.59212377e-03 ]], [[ - 1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03, - 8.93574394e-03, -7.28237582e-03 + -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, + -1.56512906e-04, 9.63474595e-05 ], [ - 1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03, - 8.94328393e-03, -7.39541650e-03 + -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, + 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, + 2.73401442e-04, -2.69805576e-04 ]]], dtype=float32), sample_id=array( - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], dtype=int32)) - expected_final_state = wrapper.AttentionWrapperState( - time=3, - attention_history=(), - cell_state=core_rnn_cell.LSTMStateTuple( + + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( c=array( [[ - -0.02204997, -0.00805805, -0.00160245, 0.01609369, - -0.01380494, -0.00749439, -0.00817, -0.01209992, 0.01795316 + -2.18960866e-02, -8.04429129e-03, -1.48267671e-03, + 1.61071159e-02, -1.37981661e-02, -7.57933082e-03, + -8.28570686e-03, -1.18733812e-02, 1.78834442e-02 ], [ - 0.01727016, -0.01420713, -0.00399972, 0.03195436, - -0.03547532, -0.02138666, -0.00610335, -0.00191557, - -0.01937821 + 1.74204130e-02, -1.41935758e-02, -3.88074201e-03, + 3.19713727e-02, -3.54694910e-02, -2.14688145e-02, + -6.21731905e-03, -1.69229065e-03, -1.94492843e-02 ], [ - -0.01160429, 0.00876595, -0.01641685, -0.01400784, - 0.01348004, -0.01036458, 0.00627241, -0.00963544, - -0.00650568 + -1.14494488e-02, 8.77974741e-03, -1.62960067e-02, + -1.39961652e-02, 1.34879015e-02, -1.04502086e-02, + 6.15879148e-03, -9.40956455e-03, -6.57592434e-03 ], [ - -0.04763246, -0.01192755, -0.00019379, 0.04103841, - -0.00138055, 0.02126456, -0.02793905, -0.0546775, - -0.02912027 + -4.74739634e-02, -1.19136050e-02, -7.36759976e-05, + 4.10547927e-02, -1.36767328e-03, 2.11772677e-02, + -2.80479677e-02, -5.44514805e-02, -2.91903690e-02 ], [ - 0.02241185, -0.00141741, 0.01911988, 0.00547728, - -0.01280068, -0.00307024, -0.00494239, 0.02169247, - 0.01631995 + 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, + 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, + -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 ]], dtype=float32), h=array( [[ - -1.10612623e-02, -3.98178305e-03, -8.15406092e-04, - 7.90496264e-03, -7.02379830e-03, -3.76371504e-03, - -4.16189339e-03, -6.17096573e-03, 8.95528216e-03 + -1.09839402e-02, -3.97479767e-03, -7.54472159e-04, + 7.91201927e-03, -7.02175125e-03, -3.80689627e-03, + -4.22065007e-03, -6.05447078e-03, 8.92056432e-03 ], [ - 8.60652886e-03, -7.17687514e-03, -1.94147555e-03, - 1.62586085e-02, -1.76823605e-02, -1.06194830e-02, - -3.01912241e-03, -9.57269047e-04, -9.95719433e-03 + 8.68127123e-03, -7.16970162e-03, -1.88375649e-03, + 1.62681788e-02, -1.76830534e-02, -1.06617520e-02, + -3.07536125e-03, -8.45551898e-04, -9.99375992e-03 ], [ - -5.78764686e-03, 4.49441886e-03, -8.13564472e-03, - -6.95375400e-03, 6.75391173e-03, -5.07880514e-03, - 3.11744539e-03, -4.72871540e-03, -3.20470310e-03 + -5.71034756e-03, 4.50129062e-03, -8.07590690e-03, + -6.94835978e-03, 6.75921654e-03, -5.12148207e-03, + 3.06083867e-03, -4.61710012e-03, -3.23932176e-03 ], [ - -2.38018595e-02, -5.89303859e-03, -9.74571449e-05, - 2.01695058e-02, -6.82353624e-04, 1.07099945e-02, - -1.42086931e-02, -2.70793252e-02, -1.44684194e-02 + -2.37224493e-02, -5.88587578e-03, -3.70525813e-05, + 2.01787278e-02, -6.76127791e-04, 1.06675029e-02, + -1.42634306e-02, -2.69631632e-02, -1.45033058e-02 ], [ - 1.11825848e-02, -6.99267141e-04, 9.82748345e-03, - 2.74566701e-03, -6.56377291e-03, -1.53681310e-03, - -2.48806458e-03, 1.10462429e-02, 7.97568541e-03 + 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, + 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, + -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 ]], dtype=float32)), attention=array( [[ - 1.22659411e-02, -6.81970268e-03, 3.15135531e-03, - -1.96937821e-03, 5.62768336e-03, -1.39173865e-02 + 0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369, + -0.00654224 ], [ - 1.57857724e-02, -1.09536834e-02, 4.64798324e-03, - -1.01319887e-03, 1.22695938e-02, -1.25500849e-02 + 0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879, + -0.00517592 ], [ - 9.49526206e-03, -9.39475093e-03, -8.49372707e-04, - -1.72815053e-05, 4.16132808e-03, -1.66336838e-02 + -0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048, + -0.00925842 ], [ - 1.15833860e-02, -1.31145213e-02, -2.84505659e-04, - 2.48642010e-03, 6.93593081e-03, -7.82784075e-04 + 0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558, + 0.00659212 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, + -0.00026981 ]], - dtype=float32)) + dtype=float32), + time=3, + attention_history=()) self._testWithAttention( create_attention_mechanism, expected_final_output, expected_final_state, - attention_mechanism_depth=9) + attention_mechanism_depth=9, + name="testLuongNotNormalized") - def testLuongNormalized(self): + def testLuongScaled(self): create_attention_mechanism = functools.partial( - wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0) + wrapper.LuongAttention, scale=True) - array = np.array - float32 = np.float32 - int32 = np.int32 - - expected_final_output = basic_decoder.BasicDecoderOutput( + expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.23956744e-02, -6.88115368e-03, 3.15234554e-03, - -1.97300944e-03, 4.79680905e-03, -1.38076628e-02 + 1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03, + -4.39571124e-03, -6.32379763e-03 ], [ - 1.28376717e-02, -6.78718928e-03, 3.07988771e-03, - -2.03956687e-03, 5.68403490e-03, -1.35601182e-02 + 2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03, + -3.42792575e-03, -6.17497414e-03 ], [ - 1.23463338e-02, -6.76322030e-03, 3.28891934e-03, - -1.86874042e-03, 5.47897862e-03, -1.37654068e-02 + 1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03, + -3.46369296e-03, -6.54224353e-03 ]], [[ - 1.54412268e-02, -1.07613346e-02, 4.43824846e-03, - -8.81063985e-04, 1.26828086e-02, -1.21067995e-02 + 4.77780215e-03, -1.98677275e-03, 3.30950436e-03, + -2.68179504e-03, 3.56271653e-03, -4.67860466e-03 ], [ - 1.57206059e-02, -1.08218864e-02, 4.61952807e-03, - -9.61483689e-04, 1.22140013e-02, -1.26614980e-02 + 5.13039157e-03, -2.02797214e-03, 3.50760575e-03, + -2.83981953e-03, 3.13726603e-03, -5.31156827e-03 ], [ - 1.57821011e-02, -1.09842420e-02, 4.66934917e-03, - -9.85997496e-04, 1.22719472e-02, -1.25438003e-02 + 5.17205056e-03, -2.16446724e-03, 3.53219034e-03, + -2.86490913e-03, 3.17879021e-03, -5.17592067e-03 ]], [[ - 9.27361846e-03, -9.66077764e-03, -9.69522633e-04, - 1.48308463e-05, 3.88664147e-03, -1.64083000e-02 + -1.38538703e-03, -6.40910701e-04, -2.02864106e-03, + -1.79018872e-03, -5.18789608e-03, -8.95875692e-03 ], [ - 9.26287938e-03, -9.74234194e-03, -8.32488062e-04, - 5.83778601e-05, 3.52663640e-03, -1.66827720e-02 + -1.38620089e-03, -7.92010222e-04, -1.91070826e-03, + -1.76206254e-03, -5.56525169e-03, -9.27332044e-03 ], [ - 9.50474478e-03, -9.49789397e-03, -8.71829456e-04, - -3.09986062e-05, 4.13423358e-03, -1.66635048e-02 + -1.11966045e-03, -6.07630936e-04, -1.96643686e-03, + -1.86803937e-03, -4.93048411e-03, -9.25842486e-03 ]], [[ - 1.21398102e-02, -1.27454493e-02, 1.57688977e-04, 2.70034792e-03, - 7.79653806e-03, -8.36936757e-04 + 1.50820788e-03, -3.93087184e-03, -9.52563598e-04, + 8.43994785e-04, -1.29030924e-03, 6.48857141e-03 ], [ - 1.18234595e-02, -1.33170560e-02, 4.55579720e-05, 2.67185434e-03, - 6.99766818e-03, -1.00935437e-03 + 1.17029145e-03, -4.45716921e-03, -1.05062663e-03, + 8.08141369e-04, -2.08062865e-03, 6.23444980e-03 ], [ - 1.16009805e-02, -1.31483339e-02, -2.94458936e-04, - 2.49248254e-03, 6.92958105e-03, -7.20315147e-04 + 9.67921398e-04, -4.32466762e-03, -1.40085898e-03, + 6.35969569e-04, -2.15558149e-03, 6.59212377e-03 ]], [[ - 1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03, - 8.93574394e-03, -7.28237582e-03 + -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, + -1.56512906e-04, 9.63474595e-05 ], [ - 1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03, - 8.94328393e-03, -7.39541650e-03 + -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, + 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, + 2.73401442e-04, -2.69805576e-04 ]]], dtype=float32), sample_id=array( - [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], dtype=int32)) - expected_final_state = wrapper.AttentionWrapperState( - time=3, - attention_history=(), - cell_state=core_rnn_cell.LSTMStateTuple( + + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( c=array( [[ - -0.02204949, -0.00805957, -0.001603, 0.01609283, - -0.01380462, -0.0074945, -0.00816895, -0.01210009, - 0.01795324 + -2.18960866e-02, -8.04429129e-03, -1.48267671e-03, + 1.61071159e-02, -1.37981661e-02, -7.57933082e-03, + -8.28570686e-03, -1.18733812e-02, 1.78834442e-02 ], [ - 0.01727016, -0.01420708, -0.00399973, 0.03195432, - -0.03547529, -0.02138673, -0.00610332, -0.00191565, - -0.01937822 + 1.74204130e-02, -1.41935758e-02, -3.88074201e-03, + 3.19713727e-02, -3.54694910e-02, -2.14688145e-02, + -6.21731905e-03, -1.69229065e-03, -1.94492843e-02 ], [ - -0.01160676, 0.00876512, -0.01641791, -0.01400807, - 0.01347767, -0.01036341, 0.00627499, -0.00963627, - -0.00650573 + -1.14494488e-02, 8.77974741e-03, -1.62960067e-02, + -1.39961652e-02, 1.34879015e-02, -1.04502086e-02, + 6.15879148e-03, -9.40956455e-03, -6.57592434e-03 ], [ - -0.04763342, -0.01192671, -0.00019402, 0.04103871, - -0.00138017, 0.02126611, -0.02793773, -0.05467714, - -0.02912043 + -4.74739634e-02, -1.19136050e-02, -7.36759976e-05, + 4.10547927e-02, -1.36767328e-03, 2.11772677e-02, + -2.80479677e-02, -5.44514805e-02, -2.91903690e-02 ], [ - 0.02241185, -0.00141741, 0.01911988, 0.00547728, - -0.01280068, -0.00307024, -0.00494239, 0.02169247, - 0.01631995 + 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, + 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, + -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 ]], dtype=float32), h=array( [[ - -1.10610286e-02, -3.98253463e-03, -8.15684092e-04, - 7.90454168e-03, -7.02364743e-03, -3.76377185e-03, - -4.16135695e-03, -6.17104582e-03, 8.95532966e-03 + -1.09839402e-02, -3.97479767e-03, -7.54472159e-04, + 7.91201927e-03, -7.02175125e-03, -3.80689627e-03, + -4.22065007e-03, -6.05447078e-03, 8.92056432e-03 ], [ - 8.60653073e-03, -7.17685232e-03, -1.94147974e-03, - 1.62585936e-02, -1.76823437e-02, -1.06195193e-02, - -3.01911240e-03, -9.57308919e-04, -9.95720550e-03 + 8.68127123e-03, -7.16970162e-03, -1.88375649e-03, + 1.62681788e-02, -1.76830534e-02, -1.06617520e-02, + -3.07536125e-03, -8.45551898e-04, -9.99375992e-03 ], [ - -5.78888878e-03, 4.49400023e-03, -8.13617278e-03, - -6.95386063e-03, 6.75271638e-03, -5.07823005e-03, - 3.11873178e-03, -4.72912844e-03, -3.20472987e-03 + -5.71034756e-03, 4.50129062e-03, -8.07590690e-03, + -6.94835978e-03, 6.75921654e-03, -5.12148207e-03, + 3.06083867e-03, -4.61710012e-03, -3.23932176e-03 ], [ - -2.38023344e-02, -5.89262368e-03, -9.75721487e-05, - 2.01696623e-02, -6.82163402e-04, 1.07107637e-02, - -1.42080421e-02, -2.70791352e-02, -1.44685050e-02 + -2.37224493e-02, -5.88587578e-03, -3.70525813e-05, + 2.01787278e-02, -6.76127791e-04, 1.06675029e-02, + -1.42634306e-02, -2.69631632e-02, -1.45033058e-02 ], [ - 1.11825848e-02, -6.99267141e-04, 9.82748345e-03, - 2.74566701e-03, -6.56377291e-03, -1.53681310e-03, - -2.48806458e-03, 1.10462429e-02, 7.97568541e-03 + 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, + 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, + -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 ]], dtype=float32)), attention=array( [[ - 1.23463338e-02, -6.76322030e-03, 3.28891934e-03, - -1.86874042e-03, 5.47897862e-03, -1.37654068e-02 + 0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369, + -0.00654224 ], [ - 1.57821011e-02, -1.09842420e-02, 4.66934917e-03, - -9.85997496e-04, 1.22719472e-02, -1.25438003e-02 + 0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879, + -0.00517592 ], [ - 9.50474478e-03, -9.49789397e-03, -8.71829456e-04, - -3.09986062e-05, 4.13423358e-03, -1.66635048e-02 + -0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048, + -0.00925842 ], [ - 1.16009805e-02, -1.31483339e-02, -2.94458936e-04, - 2.49248254e-03, 6.92958105e-03, -7.20315147e-04 + 0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558, + 0.00659212 ], [ - 1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03, - 9.36511997e-03, -7.64556089e-03 + 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, + -0.00026981 ]], - dtype=float32)) + dtype=float32), + time=3, + attention_history=()) + self._testWithAttention( create_attention_mechanism, expected_final_output, expected_final_state, - attention_mechanism_depth=9) + attention_mechanism_depth=9, + name="testLuongScaled") if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 1da06c0c093..ac79dbecefc 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -176,19 +176,18 @@ class LuongAttention(_BaseAttentionMechanism): "Effective Approaches to Attention-based Neural Machine Translation." EMNLP 2015. https://arxiv.org/abs/1508.04025 - The second is the normalized form. This form is inspired by the - normalization proposed for Bahdanau attention in - - Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck. - "Online and Linear-Time Attention by Enforcing Monotonic Alignments." - (Eq. 15). + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. To enable the second form, construct the object with parameter - `normalize=True`. + `scale=True`. """ - def __init__(self, num_units, memory, memory_sequence_length=None, - normalize=False, attention_r_initializer=None, + def __init__(self, + num_units, + memory, + memory_sequence_length=None, + scale=False, name="LuongAttention"): """Construct the AttentionMechanism mechanism. @@ -199,31 +198,21 @@ class LuongAttention(_BaseAttentionMechanism): memory_sequence_length (optional): Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. - normalize: Python boolean. Whether to normalize the energy term. - attention_r_initializer: Initial value of the post-normalization bias - when normalizing. Default is `0`. + scale: Python boolean. Whether to scale the energy term. name: Name to use when creating ops. """ # For LuongAttention, we only transform the memory layer; thus # num_units **must** match expected the query depth. super(LuongAttention, self).__init__( query_layer=None, - memory_layer=layers_core.Dense(num_units, name="memory_layer"), + memory_layer=layers_core.Dense( + num_units, name="memory_layer", use_bias=False), memory=memory, memory_sequence_length=memory_sequence_length, name=name) self._num_units = num_units - self._normalize = normalize + self._scale = scale self._name = name - if normalize and attention_r_initializer is None: - attention_r_initializer = 0 - if normalize: - with ops.name_scope(name, "LuongAttention", - [memory, attention_r_initializer]): - attention_r_initializer = ops.convert_to_tensor( - attention_r_initializer, dtype=self.values.dtype, - name="attention_r_initializer") - self._attention_r_initializer = attention_r_initializer def __call__(self, query): """Score the query based on the keys and values. @@ -249,7 +238,7 @@ class LuongAttention(_BaseAttentionMechanism): % (query, depth, self.keys, key_units, key_units)) dtype = query.dtype - with ops.name_scope(None, "LuongAttentionCall", [query]): + with variable_scope.variable_scope(None, "luong_attention", [query]): # Reshape from [batch_size, depth] to [batch_size, 1, depth] # for matmul. query = array_ops.expand_dims(query, 1) @@ -266,16 +255,11 @@ class LuongAttention(_BaseAttentionMechanism): score = math_ops.matmul(query, self.keys, transpose_b=True) score = array_ops.squeeze(score, [1]) - if self._normalize: - # Scalar used in weight normalization + if self._scale: + # Scalar used in weight scaling g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=math.sqrt((1. / self._num_units))) - # Scalar bias added to attention scores - r = variable_scope.get_variable( - "attention_r", dtype=dtype, - initializer=self._attention_r_initializer) - score = g * score + r + "attention_g", dtype=dtype, initializer=1.) + score = g * score return score @@ -290,18 +274,23 @@ class BahdanauAttention(_BaseAttentionMechanism): "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR 2015. https://arxiv.org/abs/1409.0473 - The second is the normalized form, Raffel attention, as described in: + The second is the normalized form. This form is inspired by the + weight normalization article: - Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck. - "Online and Linear-Time Attention by Enforcing Monotonic Alignments." - (Eq. 15). + Tim Salimans, Diederik P. Kingma. + "Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks." + https://arxiv.org/abs/1602.07868 To enable the second form, construct the object with parameter `normalize=True`. """ - def __init__(self, num_units, memory, memory_sequence_length=None, - normalize=False, attention_r_initializer=None, + def __init__(self, + num_units, + memory, + memory_sequence_length=None, + normalize=False, name="BahdanauAttention"): """Construct the Attention mechanism. @@ -313,28 +302,19 @@ class BahdanauAttention(_BaseAttentionMechanism): in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. - attention_r_initializer: Initial value of the post-normalization bias - when normalizing. Default is `0`. name: Name to use when creating ops. """ super(BahdanauAttention, self).__init__( - query_layer=layers_core.Dense(num_units, name="query_layer"), - memory_layer=layers_core.Dense(num_units, name="memory_layer"), + query_layer=layers_core.Dense( + num_units, name="query_layer", use_bias=False), + memory_layer=layers_core.Dense( + num_units, name="memory_layer", use_bias=False), memory=memory, memory_sequence_length=memory_sequence_length, name=name) self._num_units = num_units self._normalize = normalize self._name = name - if normalize and attention_r_initializer is None: - attention_r_initializer = 0 - if normalize: - with ops.name_scope(name, "BahdanauAttention", - [memory, attention_r_initializer]): - attention_r_initializer = ops.convert_to_tensor( - attention_r_initializer, dtype=self.values.dtype, - name="attention_r_initializer") - self._attention_r_initializer = attention_r_initializer def __call__(self, query): """Score the query based on the keys and values. @@ -347,7 +327,7 @@ class BahdanauAttention(_BaseAttentionMechanism): score: Tensor of dtype matching `self.values` and shape `[batch_size, max_time]` (`max_time` is memory's `max_time`). """ - with ops.name_scope(None, "BahndahauAttentionCall", [query]): + with variable_scope.variable_scope(None, "bahdanau_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query dtype = processed_query.dtype # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. @@ -363,15 +343,11 @@ class BahdanauAttention(_BaseAttentionMechanism): b = variable_scope.get_variable( "attention_b", [self._num_units], dtype=dtype, initializer=init_ops.zeros_initializer()) - # Scalar bias added to attention scores - r = variable_scope.get_variable( - "attention_r", dtype=dtype, - initializer=self._attention_r_initializer) # normed_v = g * v / ||v|| normed_v = g * v * math_ops.rsqrt( math_ops.reduce_sum(math_ops.square(v))) score = math_ops.reduce_sum( - normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) + r + normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) else: score = math_ops.reduce_sum( v * math_ops.tanh(self.keys + processed_query), [2]) @@ -481,7 +457,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): self._attention_mechanism = attention_mechanism self._attention_size = attention_size self._attention_layer = layers_core.Dense( - attention_size, bias_initializer=None) + attention_size, name="attention_layer", use_bias=False) self._cell_input_fn = cell_input_fn self._probability_fn = probability_fn self._output_attention = output_attention @@ -550,44 +526,44 @@ class AttentionWrapper(core_rnn_cell.RNNCell): if scope is not None: raise NotImplementedError("scope not None is not supported") - # Step 1: Calculate the true inputs to the cell based on the - # previous attention value. - cell_inputs = self._cell_input_fn(inputs, state.attention) - cell_state = state.cell_state + with variable_scope.variable_scope("attention"): + # Step 1: Calculate the true inputs to the cell based on the + # previous attention value. + cell_inputs = self._cell_input_fn(inputs, state.attention) + cell_state = state.cell_state - cell_output, next_cell_state = self._cell(cell_inputs, cell_state) + cell_output, next_cell_state = self._cell(cell_inputs, cell_state) - score = self._attention_mechanism(cell_output) - alignments = self._probability_fn(score) + score = self._attention_mechanism(cell_output) + alignments = self._probability_fn(score) - # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] - alignments = array_ops.expand_dims(alignments, 1) - # Context is the inner product of alignments and values along the - # memory time dimension. - # alignments shape is - # [batch_size, 1, memory_time] - # attention_mechanism.values shape is - # [batch_size, memory_time, attention_mechanism.num_units] - # the batched matmul is over memory_time, so the output shape is - # [batch_size, 1, attention_mechanism.num_units]. - # we then squeeze out the singleton dim. - context = math_ops.matmul(alignments, self._attention_mechanism.values) - context = array_ops.squeeze(context, [1]) + # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] + alignments = array_ops.expand_dims(alignments, 1) + # Context is the inner product of alignments and values along the + # memory time dimension. + # alignments shape is + # [batch_size, 1, memory_time] + # attention_mechanism.values shape is + # [batch_size, memory_time, attention_mechanism.num_units] + # the batched matmul is over memory_time, so the output shape is + # [batch_size, 1, attention_mechanism.num_units]. + # we then squeeze out the singleton dim. + context = math_ops.matmul(alignments, self._attention_mechanism.values) + context = array_ops.squeeze(context, [1]) - attention = self._attention_layer( - array_ops.concat([cell_output, context], 1)) + attention = self._attention_layer( + array_ops.concat([cell_output, context], 1)) - if self._attention_history: - attention_history = state.attention_history.write( - state.time, attention) - else: - attention_history = () + if self._attention_history: + attention_history = state.attention_history.write(state.time, attention) + else: + attention_history = () - next_state = AttentionWrapperState( - time=state.time + 1, - cell_state=next_cell_state, - attention=attention, - attention_history=attention_history) + next_state = AttentionWrapperState( + time=state.time + 1, + cell_state=next_cell_state, + attention=attention, + attention_history=attention_history) if self._output_attention: return attention, next_state diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 3c539728495..b8561a8458b 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -269,7 +269,12 @@ class SparseTensor(ItemHandler): class Image(ItemHandler): """An ItemHandler that decodes a parsed Tensor as an image.""" - def __init__(self, image_key=None, format_key=None, shape=None, channels=3): + def __init__(self, + image_key=None, + format_key=None, + shape=None, + channels=3, + dtype=dtypes.uint8): """Initializes the image. Args: @@ -282,6 +287,11 @@ class Image(ItemHandler): accordingly. If left as None, no reshaping is done. A shape should be supplied only if all the stored images have the same shape. channels: the number of channels in the image. + dtype: images will be decoded at this bit depth. Different formats + support different bit depths. + See tf.image.decode_png, + tf.decode_raw, + tf.image.decode_jpeg: only supports tf.uint8 """ if not image_key: image_key = 'image/encoded' @@ -293,6 +303,7 @@ class Image(ItemHandler): self._format_key = format_key self._shape = shape self._channels = channels + self._dtype = dtype def tensors_to_item(self, keys_to_tensors): """See base class.""" @@ -314,12 +325,17 @@ class Image(ItemHandler): """ def decode_png(): - return image_ops.decode_png(image_buffer, self._channels) + return image_ops.decode_png( + image_buffer, self._channels, dtype=self._dtype) def decode_raw(): - return parsing_ops.decode_raw(image_buffer, dtypes.uint8) + return parsing_ops.decode_raw(image_buffer, out_type=self._dtype) def decode_jpg(): + if self._dtype != dtypes.uint8: + raise ValueError( + 'jpeg decoder can only be used to decode to tf.uint8 but %s was ' + 'requested for a jpeg image.' % self._dtype) return image_ops.decode_jpeg(image_buffer, self._channels) # For RGBA images JPEG is not a valid decoder option. @@ -401,6 +417,7 @@ class TFExampleDecoder(data_decoder.DataDecoder): """ example = parsing_ops.parse_single_example(serialized_example, self._keys_to_features) + print(example.keys()) # Reshape non-sparse elements just once: for k in self._keys_to_features: diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 179b6d23c63..dd3c6a39a24 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -224,6 +224,18 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllClose(image, decoded_image, atol=0) + def testDecodeExampleWithJpegEncodingAt16BitCausesError(self): + image_shape = (2, 3, 3) + unused_image, serialized_example = self.GenerateImage( + image_format='jpeg', image_shape=image_shape) + expected_regex = ('jpeg decoder can only be used to decode to tf.uint8 but ' + '.* was requested for a jpeg image.') + with self.assertRaisesRegexp(ValueError, expected_regex): + unused_decoded_image = self.RunDecodeExample( + serialized_example, + tfexample_decoder.Image(dtype=dtypes.uint16), + image_format='jpeg') + def testDecodeExampleWithStringTensor(self): tensor_shape = (2, 3, 1) np_array = np.array([[['ab'], ['cd'], ['ef']], diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 1e7dd79ae77..06f8c9e18f7 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -45,6 +45,7 @@ py_library( deps = [ ":protos_all_py", "//tensorflow/python:lib", + "//tensorflow/tensorboard/plugins/projector:projector_plugin", ], ) diff --git a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py index 09a8b592f7f..c11f5d065c2 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py +++ b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Public API for the Embedding Projector.""" +"""Public API for the Embedding Projector. + +@@ProjectorPluginAsset +@@ProjectorConfig +@@EmbeddingInfo +@@EmbeddingMetadata +@@SpriteMetadata +""" from __future__ import absolute_import from __future__ import division @@ -24,8 +31,10 @@ from google.protobuf import text_format from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig from tensorflow.python.lib.io import file_io - -PROJECTOR_FILENAME = 'projector_config.pbtxt' +from tensorflow.tensorboard.plugins.projector import projector_plugin +# pylint: disable=wildcard-import +from tensorflow.tensorboard.plugins.projector.projector_plugin import * +# pylint: enable=wildcard-import def visualize_embeddings(summary_writer, config): @@ -51,4 +60,4 @@ def visualize_embeddings(summary_writer, config): # Saving the config file in the logdir. config_pbtxt = text_format.MessageToString(config) file_io.write_string_to_file( - os.path.join(logdir, PROJECTOR_FILENAME), config_pbtxt) + os.path.join(logdir, projector_plugin.PROJECTOR_FILENAME), config_pbtxt) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 79d44c5a0c7..62f30b813b2 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -486,6 +486,7 @@ cc_library( tf_gen_op_libs( op_lib_names = [ "array_ops", + "audio_ops", "candidate_sampling_ops", "control_flow_ops", "ctc_ops", @@ -553,6 +554,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":array_ops_op_lib", + ":audio_ops_op_lib", ":candidate_sampling_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index 381e25f1f72..65cbd14f14b 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -15,6 +15,7 @@ package( licenses(["notice"]) # Apache 2.0 +# Google-internal rules omitted. load( "//tensorflow:tensorflow.bzl", "check_deps", diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 8c6c36d8ccc..facbfbb2643 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/macros.h" @@ -288,41 +289,37 @@ class DeviceFinder { void Master::CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done) { SchedClosure([this, req, resp, done]() { - Status status = ValidateExternalGraphDefSyntax(req->graph_def()); - if (status.ok()) { - // Ping all the workers and build the list of devices that the - // session will use. - // TODO(saeta): Convert to std::make_unique when available. - std::unique_ptr>> remote_devices( - new std::vector>()); - status = DeviceFinder::GetRemoteDevices(req->config().device_filters(), - env_, env_->worker_cache, - remote_devices.get()); - if (!status.ok()) { - done(status); - return; - } - SessionOptions options; - options.config = req->config(); - MasterSession* session = env_->master_session_factory( - options, env_, std::move(remote_devices)); - GraphDef* gdef = - const_cast(req)->mutable_graph_def(); - Status create_status = session->Create(gdef); - if (!create_status.ok()) { - session->Close().IgnoreError(); - session->Unref(); - done(create_status); - return; - } - resp->set_session_handle(session->handle()); - // Insert into the session map, which takes ownership of the session. - { - mutex_lock l(mu_); - CHECK(sessions_.insert({session->handle(), session}).second); - } + Status status; + auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); + status = ValidateExternalGraphDefSyntax(req->graph_def()); + if (!status.ok()) return; + // Ping all the workers and build the list of devices that the + // session will use. + // TODO(saeta): Convert to std::make_unique when available. + std::unique_ptr>> remote_devices( + new std::vector>()); + status = DeviceFinder::GetRemoteDevices(req->config().device_filters(), + env_, env_->worker_cache, + remote_devices.get()); + if (!status.ok()) return; + SessionOptions options; + options.config = req->config(); + MasterSession* session = + env_->master_session_factory(options, env_, std::move(remote_devices)); + GraphDef* gdef = + const_cast(req)->mutable_graph_def(); + status = session->Create(gdef); + if (!status.ok()) { + session->Close().IgnoreError(); + session->Unref(); + return; + } + resp->set_session_handle(session->handle()); + // Insert into the session map, which takes ownership of the session. + { + mutex_lock l(mu_); + CHECK(sessions_.insert({session->handle(), session}).second); } - done(status); }); } diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index ce039c93de9..870df353cb6 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -991,8 +991,7 @@ MasterSession::MasterSession( stats_publisher_factory_(std::move(stats_publisher_factory)), graph_version_(0), run_graphs_(5), - partial_run_graphs_(5), - cancellation_manager_(new CancellationManager) { + partial_run_graphs_(5) { UpdateLastAccessTime(); VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() @@ -1015,7 +1014,6 @@ MasterSession::MasterSession( } MasterSession::~MasterSession() { - delete cancellation_manager_; for (const auto& iter : run_graphs_) iter.second->Unref(); for (const auto& iter : partial_run_graphs_) iter.second->Unref(); } @@ -1317,7 +1315,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, Status s = run_state->rcg->RunPartitions( env_, run_state->step_id, run_state->count, execution_state_.get(), - &run_state->pss, opts, req, resp, cancellation_manager_, + &run_state->pss, opts, req, resp, &cancellation_manager_, is_last_partial_run); // Delete the run state if there is an error or all fetches are done. @@ -1388,7 +1386,7 @@ Status MasterSession::DoRunWithLocalExecution( Status s = rcg->RunPartitions(env_, step_id, count, execution_state_.get(), &pss, - opts, req, resp, cancellation_manager_, false); + opts, req, resp, &cancellation_manager_, false); if (s.ok()) { pss.end_micros = Env::Default()->NowMicros(); @@ -1423,7 +1421,7 @@ Status MasterSession::Close() { mutex_lock l(mu_); closed_ = true; // All subsequent calls to Run() or Extend() will fail. } - cancellation_manager_->StartCancel(); + cancellation_manager_.StartCancel(); std::vector to_unref; { mutex_lock l(mu_); @@ -1443,7 +1441,7 @@ void MasterSession::GarbageCollect() { closed_ = true; garbage_collected_ = true; } - cancellation_manager_->StartCancel(); + cancellation_manager_.StartCancel(); Unref(); } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 39206c2eaa7..8e0460bd14b 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -173,7 +173,7 @@ class MasterSession : public core::RefCounted { int64 next_node_id_ GUARDED_BY(mu_) = 0; // Used to cancel running steps on Close(). - CancellationManager* cancellation_manager_; + CancellationManager cancellation_manager_; // Private dtor. The client must call Close(). virtual ~MasterSession(); diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 62ced697195..ecb9810d83c 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -531,7 +531,7 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype, // one both for the SubBuffer _and_ the underlying TensorBuffer. bool Tensor::RefCountIsOne() const { return buf_ != nullptr && buf_->RefCountIsOne() && - buf_->root_buffer()->RefCountIsOne(); + buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory(); } // The macro CASES() expands to a switch statement conditioned on diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 041661cee73..103da4c1b37 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -502,6 +502,9 @@ class TensorBuffer : public core::RefCounted { T* base() const { return reinterpret_cast(data()); } + + // Whether this TensorBuffer owns the underlying memory. + virtual bool OwnsMemory() const { return true; } }; template diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b05096b9183..8d4e3ad1ac3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -847,6 +847,51 @@ tf_cc_test( ], ) +tf_cc_test( + name = "decode_wav_op_test", + size = "small", + srcs = ["decode_wav_op_test.cc"], + deps = [ + ":decode_wav_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "encode_wav_op_test", + size = "small", + srcs = ["encode_wav_op_test.cc"], + deps = [ + ":decode_wav_op", + ":encode_wav_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "example_parsing_ops_test", size = "large", @@ -1656,6 +1701,31 @@ tf_kernel_library( deps = IMAGE_DEPS, ) +tf_kernel_library( + name = "encode_wav_op", + prefix = "encode_wav_op", + deps = [ + ":bounds_check", + "//tensorflow/core:audio_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_kernel_library( + name = "decode_wav_op", + prefix = "decode_wav_op", + deps = [ + "//tensorflow/core:audio_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + tf_cc_tests( name = "eigen_test", size = "small", diff --git a/tensorflow/core/kernels/decode_wav_op.cc b/tensorflow/core/kernels/decode_wav_op.cc new file mode 100644 index 00000000000..4bd5d7ac2a6 --- /dev/null +++ b/tensorflow/core/kernels/decode_wav_op.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/audio_ops.cc + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/wav/wav_io.h" + +namespace tensorflow { + +// Decode the contents of a WAV file +class DecodeWavOp : public OpKernel { + public: + explicit DecodeWavOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("desired_channels", &desired_channels_)); + OP_REQUIRES_OK(context, + context->GetAttr("desired_samples", &desired_samples_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().DebugString())); + const string wav_string = contents.scalar()(); + OP_REQUIRES(context, wav_string.size() <= std::numeric_limits::max(), + errors::InvalidArgument("WAV contents are too large for int: ", + wav_string.size())); + + std::vector decoded_samples; + uint32 decoded_sample_count; + uint16 decoded_channel_count; + uint32 decoded_sample_rate; + OP_REQUIRES_OK(context, + wav::DecodeLin16WaveAsFloatVector( + wav_string, &decoded_samples, &decoded_sample_count, + &decoded_channel_count, &decoded_sample_rate)); + + int32 output_sample_count; + if (desired_samples_ == -1) { + output_sample_count = decoded_sample_count; + } else { + output_sample_count = desired_samples_; + } + int32 output_channel_count; + if (desired_channels_ == -1) { + output_channel_count = decoded_channel_count; + } else { + output_channel_count = desired_channels_; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 0, TensorShape({output_sample_count, output_channel_count}), + &output)); + + auto output_matrix = output->matrix(); + for (int sample = 0; sample < output_sample_count; ++sample) { + for (int channel = 0; channel < output_channel_count; ++channel) { + float output_value; + if (sample >= decoded_sample_count) { + output_value = 0.0f; + } else { + int source_channel; + if (channel < decoded_channel_count) { + source_channel = channel; + } else { + source_channel = decoded_channel_count - 1; + } + const int decoded_index = + (sample * decoded_channel_count) + source_channel; + output_value = decoded_samples[decoded_index]; + } + output_matrix(sample, channel) = output_value; + } + } + + Tensor* sample_rate_output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), + &sample_rate_output)); + sample_rate_output->flat()(0) = decoded_sample_rate; + } + + private: + int32 desired_channels_; + int32 desired_samples_; +}; +REGISTER_KERNEL_BUILDER(Name("DecodeWav").Device(DEVICE_CPU), DecodeWavOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_wav_op_test.cc b/tensorflow/core/kernels/decode_wav_op_test.cc new file mode 100644 index 00000000000..c282d53a5a1 --- /dev/null +++ b/tensorflow/core/kernels/decode_wav_op_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/audio_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using namespace ops; // NOLINT(build/namespaces) + +TEST(DecodeWavOpTest, DecodeWavTest) { + Scope root = Scope::NewRootScope(); + + std::vector wav_data = { + 'R', 'I', 'F', 'F', 44, 0, 0, 0, // size of whole file - 8 + 'W', 'A', 'V', 'E', 'f', 'm', 't', ' ', 16, 0, 0, + 0, // size of fmt block - 8: 24 - 8 + 1, 0, // format: PCM (1) + 1, 0, // channels: 1 + 0x13, 0x37, 0, 0, // sample rate: 14099 + 0x26, 0x6e, 0, 0, // byte rate: 2 * 14099 + 2, 0, // block align: NumChannels * BytesPerSample + 16, 0, // bits per sample: 2 * 8 + 'd', 'a', 't', 'a', 8, 0, 0, 0, // size of payload: 8 + 0, 0, // first sample: 0 + 0xff, 0x3f, // second sample: 16383 + 0xff, 0x7f, // third sample: 32767 (saturated) + 0x00, 0x80, // fourth sample: -32768 (saturated) + }; + Tensor content_tensor = + test::AsScalar(string(wav_data.begin(), wav_data.end())); + Output content_op = + Const(root.WithOpName("content_op"), Input::Initializer(content_tensor)); + + DecodeWav decode_wav_op = + DecodeWav(root.WithOpName("decode_wav_op"), content_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), + {decode_wav_op.audio, decode_wav_op.sample_rate}, + &outputs)); + + const Tensor& audio = outputs[0]; + const int sample_rate = outputs[1].flat()(0); + + EXPECT_EQ(2, audio.dims()); + EXPECT_EQ(1, audio.dim_size(1)); + EXPECT_EQ(4, audio.dim_size(0)); + EXPECT_NEAR(0.0f, audio.flat()(0), 1e-4f); + EXPECT_NEAR(0.5f, audio.flat()(1), 1e-4f); + EXPECT_NEAR(1.0f, audio.flat()(2), 1e-4f); + EXPECT_NEAR(-1.0f, audio.flat()(3), 1e-4f); + EXPECT_EQ(14099, sample_rate); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/encode_wav_op.cc b/tensorflow/core/kernels/encode_wav_op.cc new file mode 100644 index 00000000000..ad5835aeb46 --- /dev/null +++ b/tensorflow/core/kernels/encode_wav_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/audio_ops.cc + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/wav/wav_io.h" + +namespace tensorflow { + +// Encode a tensor as audio samples into the contents of a WAV format file. +class EncodeWavOp : public OpKernel { + public: + explicit EncodeWavOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& audio = context->input(0); + OP_REQUIRES(context, audio.dims() == 2, + errors::InvalidArgument("audio must be 2-dimensional", + audio.shape().DebugString())); + const Tensor& sample_rate_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()), + errors::InvalidArgument( + "Input sample_rate should be a scalar tensor, got ", + sample_rate_tensor.shape().DebugString(), " instead.")); + const int32 sample_rate = sample_rate_tensor.scalar()(); + OP_REQUIRES( + context, + FastBoundsCheck(audio.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument( + "Cannot encode audio with >= max int32 elements")); + + const int32 channel_count = static_cast(audio.dim_size(1)); + const int32 sample_count = static_cast(audio.dim_size(0)); + + // Encode audio to wav string. + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES_OK(context, + wav::EncodeAudioAsS16LEWav( + audio.flat().data(), sample_rate, channel_count, + sample_count, &output->scalar()())); + } +}; +REGISTER_KERNEL_BUILDER(Name("EncodeWav").Device(DEVICE_CPU), EncodeWavOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/encode_wav_op_test.cc b/tensorflow/core/kernels/encode_wav_op_test.cc new file mode 100644 index 00000000000..2f92c13268b --- /dev/null +++ b/tensorflow/core/kernels/encode_wav_op_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/audio_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using namespace ops; // NOLINT(build/namespaces) + +TEST(EncodeWavOpTest, EncodeWavTest) { + Scope root = Scope::NewRootScope(); + + Tensor audio_tensor(DT_FLOAT, {4, 2}); + test::FillValues( + &audio_tensor, {0.0f, 0.5f, 1.0f, -1.0f, 0.25f, 0.75f, 1.25f, -0.5f}); + Output audio_op = + Const(root.WithOpName("audio_op"), Input::Initializer(audio_tensor)); + + Output sample_rate_op = Const(root.WithOpName("sample_rate_op"), 44100); + + EncodeWav encode_wav_op = + EncodeWav(root.WithOpName("encode_wav_op"), audio_op, sample_rate_op); + + DecodeWav decode_wav_op = + DecodeWav(root.WithOpName("decode_wav_op"), encode_wav_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), + {decode_wav_op.audio, decode_wav_op.sample_rate}, + &outputs)); + + const Tensor& audio = outputs[0]; + const int sample_rate = outputs[1].flat()(0); + + EXPECT_EQ(2, audio.dims()); + EXPECT_EQ(2, audio.dim_size(1)); + EXPECT_EQ(4, audio.dim_size(0)); + EXPECT_NEAR(0.0f, audio.flat()(0), 1e-4f); + EXPECT_NEAR(0.5f, audio.flat()(1), 1e-4f); + EXPECT_NEAR(1.0f, audio.flat()(2), 1e-4f); + EXPECT_NEAR(-1.0f, audio.flat()(3), 1e-4f); + EXPECT_NEAR(0.25f, audio.flat()(4), 1e-4f); + EXPECT_NEAR(0.75f, audio.flat()(5), 1e-4f); + EXPECT_NEAR(1.0f, audio.flat()(6), 1e-4f); + EXPECT_NEAR(-0.5f, audio.flat()(7), 1e-4f); + EXPECT_EQ(44100, sample_rate); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index 92bb6f91ee2..2fe7d993ae7 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -21,6 +21,10 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { +// function alias +constexpr auto AddOutputTensorShapeTypeByTensorShapeMap = + &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap; + /* static */ std::priority_queue> GraphTransferUtils::GetTopNFloatResults(const float* const data, const string* const labels, @@ -86,15 +90,19 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( const IGraphTransferOpsDefinitions& ops_definitions, const string& remote_graph_execute_name, const std::vector>& inputs, - const std::vector& outputs, const GraphDef& def, + const std::vector& outputs, GraphDef* original_def, GraphTransferer* const gt) { CHECK(gt != nullptr); RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( - def, inputs, true /* initialize_by_zero */, &tensor_shape_map); + *original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map); + for (NodeDef& node_def : *original_def->mutable_node()) { + TF_CHECK_OK( + AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def)); + } CHECK(status.ok()); - status = gt->LoadGraphFromProto(ops_definitions, def, inputs, outputs, false, - tensor_shape_map); + status = gt->LoadGraphFromProto(ops_definitions, *original_def, inputs, + outputs, false); const DataType input_data_type = inputs.empty() ? DT_FLOAT : inputs.at(0).second.dtype(); diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h index 80db3bc77c6..6dd203b7b8e 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h @@ -45,7 +45,7 @@ class GraphTransferUtils { const IGraphTransferOpsDefinitions& ops_definitions, const string& remote_graph_execute_name, const std::vector>& inputs, - const std::vector& outputs, const GraphDef& def, + const std::vector& outputs, GraphDef* original_def, GraphTransferer* gt); private: diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 035b0e6935d..e3fc228cc70 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -29,6 +29,10 @@ limitations under the License. namespace tensorflow { +// function alias +constexpr auto AddOutputTensorShapeTypeByTensorShapeMap = + &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap; + constexpr bool DBG_DUMP_VERIFICATION_STRING = false; constexpr bool DBG_DUMP_PARAMS = false; @@ -66,12 +70,10 @@ Status GraphTransferer::LoadGraphFromProto( const GraphDef& graph_def, const std::vector>& input_node_info_list, const std::vector& output_node_names, - const bool shape_inference_for_unknown_shape, - const TensorShapeMap& output_tensor_map) { - ImportGraphDefOptions opts; + const bool shape_inference_for_unknown_shape) { Graph graph(OpRegistry::Global()); ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); - Status status = ImportGraphDef(opts, graph_def, &graph, &shape_refiner); + Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner); if (!status.ok()) { return status; } @@ -102,7 +104,7 @@ Status GraphTransferer::LoadGraphFromProto( for (const Node* const node : graph.nodes()) { status = RegisterNodeIfAllInputsAreCached( ops_definitions, shape_refiner, *node, false, input_node_info_list, - output_node_names, output_tensor_map); + output_node_names); if (!status.ok()) { LOG(ERROR) << "Failed to transfer graph " << status; return status; @@ -123,15 +125,26 @@ Status GraphTransferer::LoadGraphFromProto( } for (const string& output_node_name : output_node_names) { + const TensorId tid = ParseTensorName(output_node_name); + const string node_name = tid.first.ToString(); + const int port = tid.second; + const int node_id = node_name_to_id_cache_map_.at(node_name); + const Node* node = node_name_cache_list_.at(node_id); + CHECK_NOTNULL(node); + GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info = *graph_transfer_info_.add_graph_output_node_info(); - graph_output_node_info.set_name(output_node_name); - if (!output_tensor_map.empty()) { - const DataType* dt; - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, output_node_name, &dt, &shape)); - graph_output_node_info.set_dtype(*dt); - for (const int64 dim : ToTensorShapeArray(*shape)) { + graph_output_node_info.set_name(strings::StrCat(node_name, ":", port)); + + // Get output tensor shape type + std::vector data_types; + std::vector shapes; + status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + node->def(), &data_types, &shapes); + if (status.ok()) { + CHECK(data_types.size() > port); + graph_output_node_info.set_dtype(data_types.at(port)); + for (const int64 dim : ToTensorShapeArray(shapes.at(port))) { graph_output_node_info.add_shape(dim); } } @@ -156,8 +169,7 @@ Status GraphTransferer::LoadGraphFromProtoFile( const std::vector>& input_node_info_list, const std::vector& output_node_names, const bool is_text_proto, const bool shape_inference_for_unknown_shape, - const bool dry_run_for_unknown_shape, - RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) { + const bool dry_run_for_unknown_shape) { GraphDef graph_def; string output; Status status; @@ -176,16 +188,21 @@ Status GraphTransferer::LoadGraphFromProtoFile( } if (dry_run_for_unknown_shape) { VLOG(1) << "Dry run graph to obtain shape of nodes"; + RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( - graph_def, input_node_info_list, true, tensor_shape_map); + graph_def, input_node_info_list, true, &tensor_shape_map); if (!status.ok()) { return status; } + for (NodeDef& node_def : *graph_def.mutable_node()) { + TF_CHECK_OK(AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, + &node_def)); + } } VLOG(1) << "Load graph with output tensors"; - return LoadGraphFromProto( - ops_definitions, graph_def, input_node_info_list, output_node_names, - shape_inference_for_unknown_shape, *tensor_shape_map); + return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list, + output_node_names, + shape_inference_for_unknown_shape); } void GraphTransferer::SortParams(const std::vector& output_node_names) { @@ -276,8 +293,7 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const { Status GraphTransferer::RegisterNode( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node, + const ShapeRefiner& shape_refiner, const Node& node, const std::vector>& input_node_info_list, const std::vector& output_node_names) { VLOG(1) << "Register node: " << node.name(); @@ -286,19 +302,16 @@ Status GraphTransferer::RegisterNode( return Status(); } else if (RemoteFusedGraphExecuteUtils::IsInputNode(input_node_info_list, node.name())) { - RegisterInputNode(ops_definitions, shape_refiner, output_tensor_map, node); + RegisterInputNode(ops_definitions, shape_refiner, node); } else if (node.IsConstant()) { - RegisterConstantNode(shape_refiner, node, output_tensor_map); + RegisterConstantNode(shape_refiner, node); } else if (HasPaddingAndStrides(node)) { - RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, - output_tensor_map, node); - } else if (IsNodeFlattenReshape(node, output_tensor_map, shape_refiner)) { - RegisterFlattenNode(ops_definitions, shape_refiner, output_tensor_map, - node); + RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node); + } else if (IsNodeFlattenReshape(node, shape_refiner)) { + RegisterFlattenNode(ops_definitions, shape_refiner, node); } else if (ops_definitions.GetOpIdFor(node.type_string()) != IGraphTransferOpsDefinitions::INVALID_OP_ID) { - RegisterGenericNode(ops_definitions, shape_refiner, output_tensor_map, - node); + RegisterGenericNode(ops_definitions, shape_refiner, node); } else { return errors::InvalidArgument(node.type_string() + " has not been implemented yet."); @@ -307,9 +320,8 @@ Status GraphTransferer::RegisterNode( return Status(); } -void GraphTransferer::RegisterConstantNode( - const ShapeRefiner& shape_refiner, const Node& node, - const TensorShapeMap& output_tensor_map) { +void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, + const Node& node) { VLOG(1) << "Register constant node: " << node.name(); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; @@ -329,19 +341,12 @@ void GraphTransferer::RegisterConstantNode( context->NumElements(shape_handle); std::array shape_array; int data_size; - if (context->ValueKnown(num_elements_dim)) { - const int64 num_output_elements = context->Value(num_elements_dim); - data_size = max_bytes_per_data * num_output_elements; - shape_array = BuildShapeArray(shape_handle, context); - CheckShape(output_tensor_map, node.name(), shape_array); - } else { - // Use output tensor for unknown shape - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, node.name(), nullptr, &shape)); - shape_array = ToTensorShapeArray(*shape); - data_size = max_bytes_per_data * shape->num_elements(); - } + // Shape of constant node must be known CHECK(context->ValueKnown(num_elements_dim)); + const int64 num_output_elements = context->Value(num_elements_dim); + data_size = max_bytes_per_data * num_output_elements; + shape_array = BuildShapeArray(shape_handle, context); + GraphTransferInfo::ConstNodeInfo& const_node_info = *graph_transfer_info_.add_const_node_info(); const_node_info.set_name(node.name()); @@ -353,14 +358,12 @@ void GraphTransferer::RegisterConstantNode( const_node_info.add_shape(shape_array[2]); const_node_info.add_shape(shape_array[3]); const TensorProto* proto = nullptr; - // TODO(b/32704451): Don't just ignore this status! - GetNodeAttr(node.def(), "value", &proto).IgnoreError(); + TF_CHECK_OK(GetNodeAttr(node.def(), "value", &proto)); Tensor const_tensor; // TODO(b/32704451): Don't just ignore this status! MakeTensorFromProto(*proto, &const_tensor).IgnoreError(); const_node_info.set_dtype(const_tensor.dtype()); - // TODO(satok): Remove. Determine constant value without dryrun if (data_size > 0) { const_node_info.set_data(const_tensor.tensor_data().data(), data_size); } @@ -395,9 +398,8 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) { node.def().attr().count(STRIDES_ATTR_NAME) > 0; } -bool GraphTransferer::IsNodeFlattenReshape( - const Node& node, const TensorShapeMap& output_tensor_map, - const ShapeRefiner& shape_refiner) { +bool GraphTransferer::IsNodeFlattenReshape(const Node& node, + const ShapeRefiner& shape_refiner) { // Check if node is reshape op if (node.type_string() != RESHAPE_NODE_TYPE_STRING) { return false; @@ -418,10 +420,13 @@ bool GraphTransferer::IsNodeFlattenReshape( if (context->ValueKnown(dim_handle)) { shape_array = BuildShapeArray(shape_handle, context); } else { - // Use output tensor for unknown shape - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, node.name(), nullptr, &shape)); - shape_array = ToTensorShapeArray(*shape); + std::vector shapes; + TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + node.def(), nullptr, &shapes)); + + // Number of outputs should be 1 for reshape node. + CHECK_EQ(1, shapes.size()); + shape_array = ToTensorShapeArray(shapes.at(0)); } // check if reshape op just does flatten @@ -434,8 +439,7 @@ bool GraphTransferer::IsNodeFlattenReshape( void GraphTransferer::RegisterNodeWithPaddingAndStrides( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node) { + const ShapeRefiner& shape_refiner, const Node& node) { CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); @@ -461,16 +465,14 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides( // Safety check of padding id CHECK(padding == Padding::VALID ? 1 : 2); AppendNodeParamsWithIoParams( - shape_refiner, output_tensor_map, node, node.name(), id, - node.type_string(), op_type_id, static_cast(padding), - node.num_inputs(), extra_inputs, node.num_outputs(), - true /* append_input */, true /* append_output */); + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + static_cast(padding), node.num_inputs(), extra_inputs, + node.num_outputs(), true /* append_input */, true /* append_output */); } void GraphTransferer::RegisterInputNode( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node) { + const ShapeRefiner& shape_refiner, const Node& node) { VLOG(1) << "Register input node: " << node.name(); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; @@ -480,15 +482,14 @@ void GraphTransferer::RegisterInputNode( << "Op" << node.name() << ", " << op_type << " is not supported," << op_type_id; AppendNodeParamsWithIoParams( - shape_refiner, output_tensor_map, node, node.name(), id, - node.type_string(), op_type_id, PADDING_NA_ID, node.num_inputs(), {}, - node.num_outputs(), true /* append_input */, true /* append_output */); + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(), + true /* append_input */, true /* append_output */); } void GraphTransferer::RegisterFlattenNode( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node) { + const ShapeRefiner& shape_refiner, const Node& node) { VLOG(1) << "Register flatten node: " << node.name(); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; @@ -497,15 +498,14 @@ void GraphTransferer::RegisterFlattenNode( CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); AppendNodeParamsWithIoParams( - shape_refiner, output_tensor_map, node, node.name(), id, - node.type_string(), op_type_id, PADDING_NA_ID, node.num_inputs(), {}, - node.num_outputs(), true /* append_input */, true /* append_output */); + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(), + true /* append_input */, true /* append_output */); } void GraphTransferer::RegisterGenericNode( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node) { + const ShapeRefiner& shape_refiner, const Node& node) { VLOG(1) << "Register generic node: " << node.name(); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; @@ -513,9 +513,9 @@ void GraphTransferer::RegisterGenericNode( CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); AppendNodeParamsWithIoParams( - shape_refiner, output_tensor_map, node, node.name(), id, - node.type_string(), op_type_id, PADDING_NA_ID, node.num_inputs(), {}, - node.num_outputs(), true /* append_input */, true /* append_output */); + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(), + true /* append_input */, true /* append_output */); } // TODO(satok): Remove this function. @@ -525,13 +525,12 @@ Status GraphTransferer::RegisterNodeIfAllInputsAreCached( const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node, const std::vector>& input_node_info_list, - const std::vector& output_node_names, - const TensorShapeMap& output_tensor_map) { + const std::vector& output_node_names) { if (only_register_const_node && !node.IsConstant()) { return Status(); } CHECK(AreAllInputsCached(node)); - return RegisterNode(ops_definitions, shape_refiner, output_tensor_map, node, + return RegisterNode(ops_definitions, shape_refiner, node, input_node_info_list, output_node_names); } @@ -583,14 +582,18 @@ void GraphTransferer::AppendNodeInputParams( } } -void GraphTransferer::AppendNodeOutputParams( - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const int id, const Node& node) { +void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner, + const int id, const Node& node) { VLOG(1) << "Append output params: " << node.name() << ", " << node.num_outputs(); GraphTransferInfo::NodeOutputInfo& node_output_info = *graph_transfer_info_.add_node_output_info(); node_output_info.set_node_id(id); + + std::vector shapes; + Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + node.def(), nullptr, &shapes); + for (int i = 0; i < node.num_outputs(); ++i) { int data_size = -1; const int output_index = i; @@ -605,19 +608,11 @@ void GraphTransferer::AppendNodeOutputParams( if (context->ValueKnown(num_elements_dim)) { const int64 num_output_elements = context->Value(num_elements_dim); data_size = max_bytes_per_data * num_output_elements; - if (!output_tensor_map.empty() && strict_check_mode_) { - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, node.name(), nullptr, &shape)); - CHECK_EQ(num_output_elements, shape->num_elements()) - << "num elements of node " << node.name() << " doesn't match " - << num_output_elements << " vs " << shape->num_elements() << ", " - << node.type_string(); - } } else { - // Use TensorShapeMap for unknown shapes - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, node.name(), nullptr, &shape)); - data_size = max_bytes_per_data * shape->num_elements(); + TF_CHECK_OK(status); + // Use attribute attached to node + CHECK_EQ(node.num_outputs(), shapes.size()) << node.name(); + data_size = max_bytes_per_data * shapes.at(i).num_elements(); } CHECK_GE(data_size, 0); node_output_info.add_max_byte_size(data_size); @@ -625,17 +620,17 @@ void GraphTransferer::AppendNodeOutputParams( } void GraphTransferer::AppendNodeParamsWithIoParams( - const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map, - const Node& node, const string& name, const int id, const string& type, - const int type_id, const int padding, const int inputs_size, - const std::vector& extra_inputs, const int outputs_size, - const bool append_input_params, const bool append_output_params) { + const ShapeRefiner& shape_refiner, const Node& node, const string& name, + const int id, const string& type, const int type_id, const int padding, + const int inputs_size, const std::vector& extra_inputs, + const int outputs_size, const bool append_input_params, + const bool append_output_params) { VLOG(1) << "Append node with io params: " << node.name(); if (append_input_params) { AppendNodeInputParams(id, node, extra_inputs); } if (append_output_params) { - AppendNodeOutputParams(shape_refiner, output_tensor_map, id, node); + AppendNodeOutputParams(shape_refiner, id, node); } AppendNodeParams(name, id, type, type_id, padding, inputs_size, extra_inputs, outputs_size); @@ -711,22 +706,6 @@ GraphTransferer::ToTensorShapeArray(const TensorShape& shape) { } } -/* static */ void GraphTransferer::CheckShape( - const TensorShapeMap& output_tensor_map, const string& node_name, - const std::array& expected) { - if (output_tensor_map.empty()) { - // As output_tensor_map is empty, skip checking tensor shape. - return; - } - const TensorShape* shape; - CHECK(FindShapeType(output_tensor_map, node_name, nullptr, &shape)); - VLOG(1) << "Check shape for " << node_name; - const std::array actual = ToTensorShapeArray(*shape); - for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) { - CHECK_EQ(expected[i], actual[i]) << node_name; - } -} - GraphTransferer::TransferParamsComparator::TransferParamsComparator( const std::unordered_map>& dep_map) : dependency_map_(dep_map) {} @@ -807,32 +786,6 @@ bool GraphTransferer::TransferParamsComparator::operator()( tensor_proto.DebugString()); } -/* static */ bool GraphTransferer::FindShapeType( - const TensorShapeMap& tensor_shape_map, const string& name, const int port, - const DataType** dt, const TensorShape** shape) { - const std::pair* tensor_shape_type = - RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, name, - port); - if (tensor_shape_type == nullptr) { - return false; - } - if (dt != nullptr) { - *dt = &tensor_shape_type->first; - } - if (shape != nullptr) { - *shape = &tensor_shape_type->second; - } - return true; -} - -/* static */ bool GraphTransferer::FindShapeType( - const TensorShapeMap& tensor_shape_map, const string& name, - const DataType** dt, const TensorShape** shape) { - const TensorId tid = ParseTensorName(name); - return FindShapeType(tensor_shape_map, tid.first.ToString(), tid.second, dt, - shape); -} - void GraphTransferer::ClearCache() { node_name_cache_list_.clear(); node_name_to_id_cache_map_.clear(); diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index 7289e38293c..60b58fd5006 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -57,8 +57,7 @@ class GraphTransferer { const GraphDef& graph_def, const std::vector>& input_node_info_list, const std::vector& output_node_names, - const bool shape_inference_for_unkown_shape, - const TensorShapeMap& output_tensor_map); + const bool shape_inference_for_unkown_shape); // Load graph structure into GraphTransferer from protobuf file // TODO(satok): Pass a pair of TensorShape and DataType instead of @@ -69,8 +68,7 @@ class GraphTransferer { const std::vector>& input_node_info_list, const std::vector& output_node_names, const bool is_text_proto, const bool shape_inference_for_unknown_shape, - const bool dry_run_for_unknown_shape, - RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map); + const bool dry_run_for_unknown_shape); // Sort params so that all input nodes appear before consumer nodes. // CAVEAT: This may be slow if the number of nodes are too large @@ -106,13 +104,12 @@ class GraphTransferer { Status RegisterNode( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node, + const ShapeRefiner& shape_refiner, const Node& node, const std::vector>& input_node_info_list, const std::vector& output_node_names); - void RegisterConstantNode(const ShapeRefiner& shape_refiner, const Node& node, - const TensorShapeMap& output_tensor_map); + void RegisterConstantNode(const ShapeRefiner& shape_refiner, + const Node& node); int RegisterConstantShape(const std::vector& shape); @@ -122,27 +119,22 @@ class GraphTransferer { // TODO(satok): Remove this method once generic reshape op is implemented in // SOC bool IsNodeFlattenReshape(const Node& node, - const TensorShapeMap& output_tensor_map, const ShapeRefiner& shape_refiner); void RegisterNodeWithPaddingAndStrides( const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node); + const ShapeRefiner& shape_refiner, const Node& node); void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node); void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node); void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node); Status RegisterNodeIfAllInputsAreCached( @@ -150,8 +142,7 @@ class GraphTransferer { const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node, const std::vector>& input_node_info_list, - const std::vector& output_node_names, - const TensorShapeMap& output_tensor_map); + const std::vector& output_node_names); void AppendNodeParams(const string& name, const int id, const string& type, const int type_id, const int padding, @@ -163,7 +154,6 @@ class GraphTransferer { const std::vector& extra_inputs); void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const int id, const Node& node); static std::array BuildShapeArray( @@ -171,22 +161,17 @@ class GraphTransferer { shape_inference::InferenceContext* context); void AppendNodeParamsWithIoParams( - const ShapeRefiner& shape_refiner, - const TensorShapeMap& output_tensor_map, const Node& node, - const string& name, const int id, const string& type, const int type_id, - const int padding, const int inputs_size, - const std::vector& extra_inputs, const int outputs_size, - const bool append_input_params, const bool append_output_params); + const ShapeRefiner& shape_refiner, const Node& node, const string& name, + const int id, const string& type, const int type_id, const int padding, + const int inputs_size, const std::vector& extra_inputs, + const int outputs_size, const bool append_input_params, + const bool append_output_params); static std::array ToTensorShapeArray( const TensorShape& shape); static string ToPaddingDebugString(int padding); - static void CheckShape(const TensorShapeMap& output_tensor_map, - const string& node_name, - const std::array& actual); - // Create dependency map static void FillDependencyRec( int node_id, std::unordered_map>& dep_map, @@ -196,14 +181,6 @@ class GraphTransferer { static Status MakeTensorFromProto(const TensorProto& tensor_proto, Tensor* tensor); - static bool FindShapeType(const TensorShapeMap& tensor_shape_map, - const string& name, const int port, - const DataType** dt, const TensorShape** shape); - - static bool FindShapeType(const TensorShapeMap& tensor_shape_map, - const string& name, const DataType** dt, - const TensorShape** shape); - void ClearCache(); // Dump pretty print of parameters diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index ad407e02bac..18419fcc50f 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -267,7 +267,7 @@ TEST_F(GraphTransfererTest, LoadAddGraph) { GraphDef def = CreateAddGraphDef(); ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, {}, std::vector{NAME_A_PLUS_B}, - false, EMPTY_OUTPUT_TENSOR_MAP) + false) .ok()); SanityCheckNodes(gt_); @@ -308,9 +308,8 @@ TEST_F(GraphTransfererTest, LoadAddGraphWithOutputTensorMap) { def, inputs, {}, &output_tensor_info); ASSERT_TRUE(status.ok()) << status; const std::vector output_node_names = {NAME_A_PLUS_B}; - status = - gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, inputs, - output_node_names, false, output_tensor_info); + status = gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, + inputs, output_node_names, false); ASSERT_TRUE(status.ok()); } @@ -322,7 +321,7 @@ TEST_F(GraphTransfererTest, LoadConvGraph) { const std::vector output_node_names = {"softmax"}; ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, input_node_info_list, output_node_names, - false, EMPTY_OUTPUT_TENSOR_MAP) + false) .ok()); SanityCheckNodes(gt_); const int const_node_count = @@ -348,7 +347,7 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) { const std::vector output_node_names = {"softmax"}; ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, input_node_info_list, output_node_names, - false, EMPTY_OUTPUT_TENSOR_MAP) + false) .ok()); SanityCheckNodes(gt_); const int const_node_count = @@ -392,12 +391,11 @@ TEST(GraphTransferer, LoadGraphFromProtoFile) { // is_text_proto = false; // ops_definitions = &HexagonOpsDefinitions::getInstance(); - RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info; GraphTransferer gt; gt.EnableStrictCheckMode(false); Status status = gt.LoadGraphFromProtoFile( *ops_definitions, filename, input_node_info_list, output_node_names, - is_text_proto, false, true, &output_tensor_info); + is_text_proto, false, true); } TEST_F(GraphTransfererTest, BuildRemoteFusedGraphDefAddGraph) { @@ -416,7 +414,7 @@ TEST_F(GraphTransfererTest, BuildRemoteFusedGraphDefAddGraph) { GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef( TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, "remote_fused_graph_execute_node", - inputs, outputs, def, >_); + inputs, outputs, &def, >_); EXPECT_EQ(3, fused_graph_def.node_size()); } @@ -457,7 +455,6 @@ TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) { // ops_definitions = &HexagonOpsDefinitions::getInstance(); // First compute using Shape inference. - RemoteFusedGraphExecuteUtils::TensorShapeMap si_output_tensor_info; GraphTransferer si_gt; si_gt.EnableStrictCheckMode(false); bool shape_inference_for_unknown_shape = true; @@ -465,12 +462,11 @@ TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) { Status status1 = si_gt.LoadGraphFromProtoFile( *ops_definitions, filename, input_node_info_list, output_node_names, is_text_proto, shape_inference_for_unknown_shape, - dry_run_for_unknown_shape, &si_output_tensor_info); + dry_run_for_unknown_shape); const GraphTransferInfo& si_graph_transfer_info = si_gt.GetGraphTransferInfo(); // Now compute using dry run. - RemoteFusedGraphExecuteUtils::TensorShapeMap dr_output_tensor_info; GraphTransferer dr_gt; dr_gt.EnableStrictCheckMode(false); shape_inference_for_unknown_shape = false; @@ -478,7 +474,7 @@ TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) { Status status2 = dr_gt.LoadGraphFromProtoFile( *ops_definitions, filename, input_node_info_list, output_node_names, is_text_proto, shape_inference_for_unknown_shape, - dry_run_for_unknown_shape, &si_output_tensor_info); + dry_run_for_unknown_shape); const GraphTransferInfo& dr_graph_transfer_info = dr_gt.GetGraphTransferInfo(); diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index eb73f2c4962..841a8868569 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -98,9 +98,12 @@ bool HexagonControlWrapper::SetupGraph() { new_output_node_info.set_input_count(1); new_output_node_info.set_output_count(0); + const TensorId tid = ParseTensorName(graph_output.name()); + const string node_name = tid.first.ToString(); + const int port = tid.second; // Register node input for the new output node const GraphTransferInfo::NodeInfo* node_info = - FindNodeInfo(graph_output.name(), &graph_transfer_info); + FindNodeInfo(node_name, &graph_transfer_info); CHECK_NE(node_info, nullptr); GraphTransferInfo::NodeInputInfo& node_input_info = *graph_transfer_info.add_node_input_info(); @@ -108,7 +111,7 @@ bool HexagonControlWrapper::SetupGraph() { GraphTransferInfo::NodeInput& node_input = *node_input_info.add_node_input(); node_input.set_node_id(node_info->node_id()); - node_input.set_output_port(0); + node_input.set_output_port(port); } if (DBG_DUMP_VERIFICATION_STRING) { @@ -292,11 +295,11 @@ bool HexagonControlWrapper::ReadOutputNode( ReadOutputNode(node_name, &outputs); CHECK_EQ(1, outputs.size()); IRemoteFusedGraphExecutor::ByteArray& output = outputs[0]; - Tensor* output = tensor_allocator(output_shape); - CHECK(output->TotalBytes() >= std::get<1>(output)) - << output->TotalBytes() << ", " << std::get<1>(output); + Tensor* output_tensor = tensor_allocator(output_shape); + CHECK(output_tensor->TotalBytes() >= std::get<1>(output)) + << output_tensor->TotalBytes() << ", " << std::get<1>(output); // TODO(satok): Avoid specifying float - std::memcpy(output->flat().data(), std::get<0>(output), + std::memcpy(output_tensor->flat().data(), std::get<0>(output), std::get<1>(output)); } diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 6aa3ff50f60..95f993bafd8 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -280,7 +280,6 @@ TEST(GraphTransferer, inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); std::vector output_node_names = {"softmax"}; - RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info; GraphTransferer gt; gt.EnableStrictCheckMode(false); profile_utils::CpuUtils::EnableClockCycleProfiling(true); @@ -288,10 +287,41 @@ TEST(GraphTransferer, prof.Start(); Status status = gt.LoadGraphFromProtoFile( *ops_definitions, MODEL_FILENAME, inputs, output_node_names, - false, // is_text_proto - USE_SHAPE_INFERENCE, // shape_inference_for_unknown_shape - !USE_SHAPE_INFERENCE, // dry_run_for_unknown_shape - &output_tensor_info); + false, // is_text_proto + false, // shape_inference_for_unknown_shape + true // dry_run_for_unknown_shape + ); + ASSERT_TRUE(status.ok()) << status; + prof.Stop(); + prof.DumpStatistics("LoadGraphFromProtoFile"); + + std::vector img_floats; + LoadImage(&img_floats); + RunInferenceByHexagonControlWrapper(gt, img_floats); +} + +TEST(GraphTransferer, + DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapperShapeInference) { + LOG(INFO) << "Run inception v3 on hexagon with hexagon controller"; + CheckHexagonControllerVersion(); + + const IGraphTransferOpsDefinitions* ops_definitions = + &HexagonOpsDefinitions::getInstance(); + std::vector> inputs; + inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); + std::vector output_node_names = {"softmax"}; + + GraphTransferer gt; + gt.EnableStrictCheckMode(false); + profile_utils::CpuUtils::EnableClockCycleProfiling(true); + ClockCycleProfiler prof; + prof.Start(); + Status status = gt.LoadGraphFromProtoFile( + *ops_definitions, MODEL_FILENAME, inputs, output_node_names, + false, // is_text_proto + true, // shape_inference_for_unknown_shape + false // dry_run_for_unknown_shape + ); ASSERT_TRUE(status.ok()) << status; prof.Stop(); prof.DumpStatistics("LoadGraphFromProtoFile"); @@ -326,7 +356,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) { gt.EnableStrictCheckMode(false); GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef( HexagonOpsDefinitions::getInstance(), - REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, graph_def, >); + REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, &graph_def, >); RunFusedGraph(fused_graph_def); } @@ -358,8 +388,10 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { prof0.Start(); Status status = gt0.LoadGraphFromProtoFile( *ops_definitions, MODEL_FILENAME, inputs, output_node_names, - false /* is_text_proto */, false /* shape_inference_for_unknown_shape */, - true /* dry_run_for_unknown_shape */, &output_tensor_info0); + false, // is_text_proto + false, // shape_inference_for_unknown_shape + true // dry_run_for_unknown_shape + ); const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); @@ -376,8 +408,10 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { prof1.Start(); status = gt1.LoadGraphFromProtoFile( *ops_definitions, MODEL_FILENAME, inputs, output_node_names, - false /* is_text_proto */, true /* shape_inference_for_unknown_shape */, - false /* dry_run_for_unknown_shape */, &output_tensor_info1); + false, // is_text_proto + true, // shape_inference_for_unknown_shape + false // dry_run_for_unknown_shape + ); const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); diff --git a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc index 71034f5a7e9..1ddc3c9074d 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc @@ -77,9 +77,10 @@ Status RewriteQuantizedStrippedModelForHexagon( } GraphTransferer gt; gt.EnableStrictCheckMode(false); + GraphDef mutable_input_graph_def = input_graph_def; *output_graph_def = GraphTransferUtils::BuildFusedGraphDef( HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node", - inputs, outputs, input_graph_def, >); + inputs, outputs, &mutable_input_graph_def, >); return Status::OK(); } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index eacd18a793c..ee470ed4655 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -135,20 +135,29 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { std::vector output_tensors; output_tensors.reserve(graph_def.node_size()); std::vector output_node_names; - for (const NodeDef& node : graph_def.node()) { - if (!IsInputNode(input_node_info_list, node.name())) { - // CAVEAT: We only support one output. Use shape Inference Version - // if there are two or more outputs in a node. - output_node_names.emplace_back(strings::StrCat(node.name(), ":", 0)); + + Graph graph(OpRegistry::Global()); + Status status = ImportGraphDef({}, graph_def, &graph, nullptr); + if (!status.ok()) { + return status; + } + + for (const Node* node : graph.nodes()) { + if (IsInputNode(input_node_info_list, node->name())) { + continue; + } + for (int i = 0; i < node->num_outputs(); ++i) { + output_node_names.emplace_back(strings::StrCat(node->name(), ":", i)); } } - const Status status = - DryRunInference(graph_def, input_node_info_list, output_node_names, - initialize_by_zero, &output_tensors); + + status = DryRunInference(graph_def, input_node_info_list, output_node_names, + initialize_by_zero, &output_tensors); if (!status.ok()) { VLOG(1) << "Failed to dryrun " << status; return status; } + CHECK_EQ(output_node_names.size(), output_tensors.size()) << output_node_names.size() << ", " << output_tensors.size(); @@ -169,7 +178,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { const Tensor& tensor = output_tensors.at(output_node_names.size() + i); EmplaceTensorShapeType(name, tensor, tensor_shape_map); } - CHECK(graph_def.node_size() == output_tensors.size()); + CHECK_EQ(output_node_names.size() + input_node_info_list.size(), + output_tensors.size()); return status; } @@ -248,6 +258,26 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( return Status::OK(); } +/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + const NodeDef& node_def, std::vector* data_types, + std::vector* shapes) { + Status status; + if (data_types != nullptr) { + status = GetNodeAttr(node_def, ATTR_OUTPUT_DATA_TYPES, data_types); + } + if (!status.ok()) { + return status; + } + if (shapes != nullptr) { + status = GetNodeAttr(node_def, ATTR_OUTPUT_SHAPES, shapes); + if (status.ok() && data_types != nullptr) { + CHECK_EQ(data_types->size(), shapes->size()); + } + } + + return status; +} + /* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference( const GraphDef& graph_def, const std::vector>& input_node_info_list, @@ -269,8 +299,13 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( shape_inference::ShapeHandle handle; status = context->MakeShapeFromTensorShape( input_node_info.second.shape(), &handle); - // TODO(b/32704451): Don't just ignore this status! - shape_refiner->SetShape(node, 0, handle).IgnoreError(); + if (!status.ok()) { + break; + } + status = shape_refiner->SetShape(node, 0, handle); + if (!status.ok()) { + break; + } is_input_node = true; } if (!status.ok()) { @@ -280,9 +315,9 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( // If not an input node call AddNode() that recomputes the shape. if (!is_input_node && status.ok()) { status = shape_refiner->AddNode(node); - if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << node->name(); - } + } + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << node->name(); } }; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index 7c198fb601e..f895c96db07 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -32,9 +32,12 @@ namespace tensorflow { // functions for IRemoteFusedGraphExecutor. class RemoteFusedGraphExecuteUtils { public: + // TODO(satok): Use "_output_data_types" to share a spec with other ops static constexpr const char* const ATTR_OUTPUT_DATA_TYPES = - "_output_data_types"; - static constexpr const char* const ATTR_OUTPUT_SHAPES = "_output_shapes"; + "_default_remote_graph_output_data_types"; + // TODO(satok): Use "_output_shapes" to share a spec with other ops + static constexpr const char* const ATTR_OUTPUT_SHAPES = + "_default_remote_output_shapes"; using ExecutorBuildFunc = std::function* executor)>; @@ -96,6 +99,10 @@ class RemoteFusedGraphExecuteUtils { static Status AddOutputTensorShapeTypeByTensorShapeMap( const TensorShapeMap& tensor_shape_map, NodeDef* node_def); + static Status GetOutputTensorShapeType(const NodeDef& node_def, + std::vector* data_types, + std::vector* shapes); + static Status PropagateShapeInference( const GraphDef& graph_def, const std::vector>& input_node_info_list, diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 21593504434..c7a5ea3a9ca 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -92,15 +92,14 @@ struct ApplyProximalGradientDescent { // compute v = w - lr * grad. prox_var.device(d) -= grad * lr(); if (l1() > 0) { - var.device(d) = prox_var.abs() - var.constant(lr() * l1()); // compute sign(v) * max(|v| - lr * l1, 0) - var.device(d) = prox_var.sign() * var.cwiseMax(T(0.0)); + var.device(d) = + prox_var.sign() * + (prox_var.abs() - var.constant(lr() * l1())).cwiseMax(T(0.0)) / + (var.constant(1.0) + var.constant(l2() * lr())); } else { - var.device(d) = prox_var; - } - if (l2() > 0) { - // compute v / (1.0 + l2 * lr) - var.device(d) = var / (var.constant(1.0) + var.constant(l2() * lr())); + var.device(d) = + prox_var / (var.constant(1.0) + var.constant(l2() * lr())); } } }; @@ -169,15 +168,14 @@ struct ApplyProximalAdagrad { // compute v = w - lr * grad. prox_var.device(d) -= grad * learning_rate; if (l1() > 0) { - var.device(d) = prox_var.abs() - learning_rate * prox_var.constant(l1()); // compute sign(v) * max(|v| - lr * l1, 0) - var.device(d) = prox_var.sign() * var.cwiseMax(T(0.0)); + var.device(d) = prox_var.sign() * + (prox_var.abs() - learning_rate * prox_var.constant(l1())) + .cwiseMax(T(0.0)) / + (var.constant(1.0) + var.constant(l2()) * learning_rate); } else { - var.device(d) = prox_var; - } - if (l2() > 0) { var.device(d) = - var / (var.constant(1.0) + var.constant(l2()) * learning_rate); + prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate); } } }; @@ -205,14 +203,17 @@ struct ApplyFtrl { if (lr_power() == static_cast(-0.5)) { auto y = new_accum.sqrt() / new_accum.constant(lr()) + linear.constant(static_cast(2) * l2()); - var.device(d) = x / y; + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1())) + .select(pre_shrink, var.constant(static_cast(0))); + } else { auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) + linear.constant(static_cast(2) * l2()); - var.device(d) = x / y; + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1())) + .select(pre_shrink, var.constant(static_cast(0))); } - var.device(d) = (linear.abs() > linear.constant(l1())) - .select(var, var.constant(static_cast(0))); accum.device(d) += grad.square(); } }; @@ -889,14 +890,14 @@ class SparseApplyProximalGradientDescentOp : public OpKernel { // v = w - g * learning_rate. prox_v -= g * learning_rate; if (l1_scalar > 0) { - v = prox_v.abs() - learning_rate * prox_v.constant(l1_scalar); // compute sign(v) * max(|v|, 0) - v = prox_v.sign() * v.cwiseMax(static_cast(0.0)); + v = prox_v.sign() * + (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar)) + .cwiseMax(static_cast(0.0)) / + (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); } else { - v = prox_v; - } - if (l2_scalar > 0) { - v /= (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); + v = prox_v / + (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); } } } else { @@ -919,14 +920,13 @@ class SparseApplyProximalGradientDescentOp : public OpKernel { auto prox_v = var_flat(index); prox_v -= learning_rate * g; if (l1_scalar > 0) { - var_flat(index) = std::abs(prox_v) - learning_rate * l1_scalar; var_flat(index) = - sgn(prox_v) * std::max(var_flat(index), static_cast(0.0)); + sgn(prox_v) * + std::max(std::abs(prox_v) - learning_rate * l1_scalar, + static_cast(0.0)) / + (1.0 + l2_scalar * learning_rate); } else { - var_flat(index) = prox_v; - } - if (l2_scalar > 0) { - var_flat(index) /= (1.0 + l2_scalar * learning_rate); + var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate); } } } @@ -1381,14 +1381,14 @@ class SparseApplyProximalAdagradOp : public OpKernel { // v = w - g * learning_rate. prox_v -= g * learning_rate; if (l1_scalar > 0) { - v = prox_v.abs() - learning_rate * prox_v.constant(l1_scalar); // compute sign(v) * max(|v|, 0) - v = prox_v.sign() * v.cwiseMax(static_cast(0.0)); + v = prox_v.sign() * + (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar)) + .cwiseMax(static_cast(0.0)) / + (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); } else { - v = prox_v; - } - if (l2_scalar > 0) { - v /= (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); + v = prox_v / + (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); } } } else { @@ -1414,14 +1414,13 @@ class SparseApplyProximalAdagradOp : public OpKernel { auto prox_v = var_flat(index); prox_v -= learning_rate * g; if (l1_scalar > 0) { - var_flat(index) = std::abs(prox_v) - learning_rate * l1_scalar; var_flat(index) = - sgn(prox_v) * std::max(var_flat(index), static_cast(0.0)); + sgn(prox_v) * + std::max(std::abs(prox_v) - learning_rate * l1_scalar, + static_cast(0.0)) / + (1.0 + l2_scalar * learning_rate); } else { - var_flat(index) = prox_v; - } - if (l2_scalar > 0) { - var_flat(index) /= (1.0 + l2_scalar * learning_rate); + var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate); } } } @@ -1672,10 +1671,10 @@ class SparseApplyAdagradDAOp : public OpKernel { ga += g; da += g.square(); if (l1_scalar > 0) { - v = (ga.abs() / ga.constant(global_step_scalar)) - - ga.constant(l1_scalar); v = ga.constant(-1.0) * ga.sign() * - v.cwiseMax(static_cast(0.0)) / + ((ga.abs() / ga.constant(global_step_scalar)) - + ga.constant(l1_scalar)) + .cwiseMax(static_cast(0.0)) / (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr)); } else { v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) / diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 31c81b7dde2..97e218a7931 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -65,21 +65,64 @@ static_assert(sizeof(WavHeader) == sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk), "TF_PACKED does not work."); +constexpr char kRiffChunkId[] = "RIFF"; +constexpr char kRiffType[] = "WAVE"; +constexpr char kFormatChunkId[] = "fmt "; +constexpr char kDataChunkId[] = "data"; + inline int16 FloatToInt16Sample(float data) { constexpr float kMultiplier = 1.0f * (1 << 15); return std::min(std::max(roundf(data * kMultiplier), kint16min), kint16max); } +inline float Int16SampleToFloat(int16 data) { + constexpr float kMultiplier = 1.0f / (1 << 15); + return data * kMultiplier; +} + +Status ExpectText(const string& data, const string& expected_text, + int* offset) { + const int new_offset = *offset + expected_text.size(); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read ", + expected_text); + } + const string found_text(data.begin() + *offset, data.begin() + new_offset); + if (found_text != expected_text) { + return errors::InvalidArgument("Header mismatch: Expected ", expected_text, + " but found ", found_text); + } + *offset = new_offset; + return Status::OK(); +} + +template +Status ReadValue(const string& data, T* value, int* offset) { + const int new_offset = *offset + sizeof(T); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read value"); + } + if (port::kLittleEndian) { + memcpy(value, data.data() + *offset, sizeof(T)); + } else { + *value = 0; + const uint8* data_buf = + reinterpret_cast(data.data() + *offset); + int shift = 0; + for (int i = 0; i < sizeof(T); ++i, shift += 8) { + *value = *value | (data_buf[i] >> shift); + } + } + *offset = new_offset; + return Status::OK(); +} + } // namespace Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, size_t num_channels, size_t num_frames, string* wav_string) { - constexpr char kRiffChunkId[] = "RIFF"; - constexpr char kRiffType[] = "WAVE"; - constexpr char kFormatChunkId[] = "fmt "; - constexpr char kDataChunkId[] = "data"; constexpr size_t kFormatChunkSize = 16; constexpr size_t kCompressionCodePcm = 1; constexpr size_t kBitsPerSample = 16; @@ -153,5 +196,79 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, return Status::OK(); } +Status DecodeLin16WaveAsFloatVector(const string& wav_string, + std::vector* float_values, + uint32* sample_count, uint16* channel_count, + uint32* sample_rate) { + int offset = 0; + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset)); + uint32 total_file_size; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &total_file_size, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset)); + uint32 format_chunk_size; + TF_RETURN_IF_ERROR( + ReadValue(wav_string, &format_chunk_size, &offset)); + if ((format_chunk_size != 16) && (format_chunk_size != 18)) { + return errors::InvalidArgument( + "Bad file size for WAV: Expected 16 or 18, but got", format_chunk_size); + } + uint16 audio_format; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &audio_format, &offset)); + if (audio_format != 1) { + return errors::InvalidArgument( + "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); + } + TF_RETURN_IF_ERROR(ReadValue(wav_string, channel_count, &offset)); + TF_RETURN_IF_ERROR(ReadValue(wav_string, sample_rate, &offset)); + uint32 bytes_per_second; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &bytes_per_second, &offset)); + uint16 bytes_per_sample; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &bytes_per_sample, &offset)); + // Confusingly, bits per sample is defined as holding the number of bits for + // one channel, unlike the definition of sample used elsewhere in the WAV + // spec. For example, bytes per sample is the memory needed for all channels + // for one point in time. + uint16 bits_per_sample; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &bits_per_sample, &offset)); + if (bits_per_sample != 16) { + return errors::InvalidArgument( + "Can only read 16-bit WAV files, but received ", bits_per_sample); + } + const uint32 expected_bytes_per_sample = + ((bits_per_sample * *channel_count) + 7) / 8; + if (bytes_per_sample != expected_bytes_per_sample) { + return errors::InvalidArgument( + "Bad bytes per sample in WAV header: Expected ", + expected_bytes_per_sample, " but got ", bytes_per_sample); + } + const uint32 expected_bytes_per_second = + (bytes_per_sample * (*sample_rate)) / *channel_count; + if (bytes_per_second != expected_bytes_per_second) { + return errors::InvalidArgument( + "Bad bytes per second in WAV header: Expected ", + expected_bytes_per_second, " but got ", bytes_per_second, + " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample, + ")"); + } + if (format_chunk_size == 18) { + // Skip over this unused section. + offset += 2; + } + TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset)); + uint32 data_size; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &data_size, &offset)); + *sample_count = data_size / bytes_per_sample; + const uint32 data_count = *sample_count * *channel_count; + float_values->resize(data_count); + for (int i = 0; i < data_count; ++i) { + int16 single_channel_value; + TF_RETURN_IF_ERROR( + ReadValue(wav_string, &single_channel_value, &offset)); + (*float_values)[i] = Int16SampleToFloat(single_channel_value); + } + return Status::OK(); +} + } // namespace wav } // namespace tensorflow diff --git a/tensorflow/core/lib/wav/wav_io.h b/tensorflow/core/lib/wav/wav_io.h index 68629996e1e..adca0ee3034 100644 --- a/tensorflow/core/lib/wav/wav_io.h +++ b/tensorflow/core/lib/wav/wav_io.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_LIB_WAV_WAV_IO_H_ #include +#include #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,18 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, size_t num_channels, size_t num_frames, string* wav_string); +// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16 +// encoding) into a float Tensor. The channels are encoded as the lowest +// dimension of the tensor, with the number of frames as the second. This means +// that a four frame stereo signal will have the shape [4, 2]. The sample rate +// is read from the file header, and an error is returned if the format is not +// supported. +// The results are output as floats within the range -1 to 1, +Status DecodeLin16WaveAsFloatVector(const string& wav_string, + std::vector* float_values, + uint32* sample_count, uint16* channel_count, + uint32* sample_rate); + } // namespace wav } // namespace tensorflow diff --git a/tensorflow/core/lib/wav/wav_io_test.cc b/tensorflow/core/lib/wav/wav_io_test.cc index 11f1bfa527a..e54b9445abc 100644 --- a/tensorflow/core/lib/wav/wav_io_test.cc +++ b/tensorflow/core/lib/wav/wav_io_test.cc @@ -78,5 +78,24 @@ TEST(WavIO, BasicOdd) { EXPECT_EQ(54, result.size()); } +TEST(WavIO, EncodeThenDecode) { + float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + string wav_data; + TF_ASSERT_OK(EncodeAudioAsS16LEWav(audio, 44100, 2, 3, &wav_data)); + std::vector decoded_audio; + uint32 decoded_sample_count; + uint16 decoded_channel_count; + uint32 decoded_sample_rate; + TF_ASSERT_OK(DecodeLin16WaveAsFloatVector( + wav_data, &decoded_audio, &decoded_sample_count, &decoded_channel_count, + &decoded_sample_rate)); + EXPECT_EQ(2, decoded_channel_count); + EXPECT_EQ(3, decoded_sample_count); + EXPECT_EQ(44100, decoded_sample_rate); + for (int i = 0; i < 6; ++i) { + EXPECT_NEAR(audio[i], decoded_audio[i], 1e-4f) << "i=" << i; + } +} + } // namespace wav } // namespace tensorflow diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc new file mode 100644 index 00000000000..d6dedc38206 --- /dev/null +++ b/tensorflow/core/ops/audio_ops.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +namespace { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +Status DecodeWavShapeFn(InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + + DimensionHandle channels_dim; + int32 desired_channels; + TF_RETURN_IF_ERROR(c->GetAttr("desired_channels", &desired_channels)); + if (desired_channels == 0) { + channels_dim = c->UnknownDim(); + } else { + if (desired_channels < 0) { + return errors::InvalidArgument("channels must be non-negative, got ", + desired_channels); + } + channels_dim = c->MakeDim(desired_channels); + } + DimensionHandle samples_dim; + int32 desired_samples; + TF_RETURN_IF_ERROR(c->GetAttr("desired_samples", &desired_samples)); + if (desired_samples == 0) { + samples_dim = c->UnknownDim(); + } else { + if (desired_samples < 0) { + return errors::InvalidArgument("samples must be non-negative, got ", + desired_samples); + } + samples_dim = c->MakeDim(desired_samples); + } + c->set_output(0, c->MakeShape({samples_dim, channels_dim})); + c->set_output(1, c->Scalar()); + return Status::OK(); +} + +Status EncodeWavShapeFn(InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + c->set_output(0, c->Scalar()); + return Status::OK(); +} + +} // namespace + +REGISTER_OP("DecodeWav") + .Input("contents: string") + .Attr("desired_channels: int = -1") + .Attr("desired_samples: int = -1") + .Output("audio: float") + .Output("sample_rate: int32") + .SetShapeFn(DecodeWavShapeFn) + .Doc(R"doc( +Decode a 16-bit PCM WAV file to a float tensor. + +The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. + +When desired_channels is set, if the input contains fewer channels than this +then the last channel will be duplicated to give the requested number, else if +the input has more channels than requested then the additional channels will be +ignored. + +If desired_samples is set, then the audio will be cropped or padded with zeroes +to the requested length. + +The first output contains a Tensor with the content of the audio samples. The +lowest dimension will be the number of channels, and the second will be the +number of samples. For example, a ten-sample-long stereo WAV file should give an +output shape of [10, 2]. + +contents: The WAV-encoded audio, usually from a file. +desired_channels: Number of sample channels wanted. +desired_samples: Length of audio requested. +audio: 2-D with shape `[length, channels]`. +sample_rate: Scalar holding the sample rate found in the WAV header. +)doc"); + +REGISTER_OP("EncodeWav") + .Input("audio: float") + .Input("sample_rate: int32") + .Output("contents: string") + .SetShapeFn(EncodeWavShapeFn) + .Doc(R"doc( +Encode audio data using the WAV file format. + +This operation will generate a string suitable to be saved out to create a .wav +audio file. It will be encoded in the 16-bit PCM format. It takes in float +values in the range -1.0f to 1.0f, and any outside that value will be clamped to +that range. + +`audio` is a 2-D float Tensor of shape `[length, channels]`. +`sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). + +audio: 2-D with shape `[length, channels]`. +sample_rate: Scalar containing the sample frequency. +contents: 0-D. WAV-encoded file contents. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 90687ebbbd6..2af4e8692b2 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -5287,6 +5287,35 @@ op { } } } +op { + name: "DecodeWav" + input_arg { + name: "contents" + type: DT_STRING + } + output_arg { + name: "audio" + type: DT_FLOAT + } + output_arg { + name: "sample_rate" + type: DT_INT32 + } + attr { + name: "desired_channels" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "desired_samples" + type: "int" + default_value { + i: -1 + } + } +} op { name: "DeleteSessionTensor" input_arg { @@ -6393,6 +6422,21 @@ op { } } } +op { + name: "EncodeWav" + input_arg { + name: "audio" + type: DT_FLOAT + } + input_arg { + name: "sample_rate" + type: DT_INT32 + } + output_arg { + name: "contents" + type: DT_STRING + } +} op { name: "Enter" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 3ab2d785bc4..2f85351d610 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -5632,6 +5632,42 @@ op { } summary: "Reinterpret the bytes of a string as a vector of numbers." } +op { + name: "DecodeWav" + input_arg { + name: "contents" + description: "The WAV-encoded audio, usually from a file." + type: DT_STRING + } + output_arg { + name: "audio" + description: "2-D with shape `[length, channels]`." + type: DT_FLOAT + } + output_arg { + name: "sample_rate" + description: "Scalar holding the sample rate found in the WAV header." + type: DT_INT32 + } + attr { + name: "desired_channels" + type: "int" + default_value { + i: -1 + } + description: "Number of sample channels wanted." + } + attr { + name: "desired_samples" + type: "int" + default_value { + i: -1 + } + description: "Length of audio requested." + } + summary: "Decode a 16-bit PCM WAV file to a float tensor." + description: "The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float.\n\nWhen desired_channels is set, if the input contains fewer channels than this\nthen the last channel will be duplicated to give the requested number, else if\nthe input has more channels than requested then the additional channels will be\nignored.\n\nIf desired_samples is set, then the audio will be cropped or padded with zeroes\nto the requested length.\n\nThe first output contains a Tensor with the content of the audio samples. The\nlowest dimension will be the number of channels, and the second will be the\nnumber of samples. For example, a ten-sample-long stereo WAV file should give an\noutput shape of [10, 2]." +} op { name: "DeleteSessionTensor" input_arg { @@ -6752,6 +6788,26 @@ op { summary: "PNG-encode an image." description: "`image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`\nwhere `channels` is:\n\n* 1: for grayscale.\n* 2: for grayscale + alpha.\n* 3: for RGB.\n* 4: for RGBA.\n\nThe ZLIB compression level, `compression`, can be -1 for the PNG-encoder\ndefault or a value from 0 to 9. 9 is the highest compression level, generating\nthe smallest output, but is slower." } +op { + name: "EncodeWav" + input_arg { + name: "audio" + description: "2-D with shape `[length, channels]`." + type: DT_FLOAT + } + input_arg { + name: "sample_rate" + description: "Scalar containing the sample frequency." + type: DT_INT32 + } + output_arg { + name: "contents" + description: "0-D. WAV-encoded file contents." + type: DT_STRING + } + summary: "Encode audio data using the WAV file format." + description: "This operation will generate a string suitable to be saved out to create a .wav\naudio file. It will be encoded in the 16-bit PCM format. It takes in float\nvalues in the range -1.0f to 1.0f, and any outside that value will be clamped to\nthat range.\n\n`audio` is a 2-D float Tensor of shape `[length, channels]`.\n`sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100)." +} op { name: "Enter" input_arg { diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index c8bae1a02af..5ee3099673f 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -273,7 +273,7 @@ class GcsRandomAccessFile : public RandomAccessFile { std::unique_ptr request(http_request_factory_->Create()); TF_RETURN_IF_ERROR(request->Init()); TF_RETURN_IF_ERROR( - request->SetUri(strings::StrCat("https://", bucket_, ".", kStorageHost, + request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket_, "/", request->EscapeString(object_)))); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetRange( diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index a7d1e60d9e0..fc79f3be110 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -33,12 +33,12 @@ class FakeAuthProvider : public AuthProvider { TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead) { std::vector requests( {new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n", "012345"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-11\n", "6789")}); @@ -67,12 +67,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead) { TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead_differentN) { std::vector requests( {new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-2\n", "012"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 3-12\n", "3456789")}); @@ -104,32 +104,32 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead_differentN) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithReadAhead) { std::vector requests( {new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n", "012345678"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-14\n", "6789abcde"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-20\n", "6789abcd"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 7-21\n", "789abcdef"), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 20-34\n", ""), new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-14\n", "01234567")}); @@ -446,7 +446,7 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://bucket.storage.googleapis.com/path%2Fappendable.txt\n" + "Uri: https://storage.googleapis.com/bucket/path%2Fappendable.txt\n" "Auth Token: fake_token\n" "Range: 0-1048575\n", "content1,"), @@ -496,7 +496,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { strings::StrCat("{\"size\": \"", content.size(), "\", \"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - strings::StrCat("Uri: https://bucket.storage.googleapis.com/" + strings::StrCat("Uri: https://storage.googleapis.com/bucket/" "path%2Frandom_access.txt\n" "Auth Token: fake_token\n" "Range: 0-", diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 57d456059f2..35a2d237e66 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -25,8 +25,99 @@ After installation, please see this [complete example](https://www.tensorflow.org/code/tensorflow/examples/android) of TensorFlow on Android. +## Using TensorFlow with a Maven project -## Install on Linux or Mac OS +If your project uses [Apache Maven](https://maven.apache.org), then add the +following to the project's `pom.xml` to use the TensorFlow Java APIs: + +```xml + + org.tensorflow + tensorflow + 1.1.0 + +``` + +That's all. + +### Example + +As an example, these steps will create a Maven project that uses TensorFlow: + +1. Create the project's `pom.xml`: + + ```xml + + 4.0.0 + org.myorg + label-image + 1.0-SNAPSHOT + + HelloTF + + + 1.7 + 1.7 + + + + org.tensorflow + tensorflow + 1.1.0 + + + + ``` + +2. Create the source file (`src/main/java/HelloTF.java`): + + ```java + import org.tensorflow.Graph; + import org.tensorflow.Session; + import org.tensorflow.Tensor; + import org.tensorflow.TensorFlow; + + public class HelloTF { + public static void main(String[] args) throws Exception { + try (Graph g = new Graph()) { + final String value = "Hello from " + TensorFlow.version(); + + // Construct the computation graph with a single operation, a constant + // named "MyConst" with a value "value". + try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) { + // The Java API doesn't yet include convenience functions for adding operations. + g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build(); + } + + // Execute the "MyConst" operation in a Session. + try (Session s = new Session(g); + Tensor output = s.runner().fetch("MyConst").run().get(0)) { + System.out.println(new String(output.bytesValue(), "UTF-8")); + } + } + } + } +``` + +3. Compile and execute: + + ```bsh + # Use -q to hide logging from the mvn tool + mvn -q compile exec:java + ``` + +The preceeding command should output `Hello from *version*`. If it does, you've +succesfully set up TensorFlow for Java and are ready to use it in Maven +projects. If not, check [Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow) +for possible solutions. You can skip reading the rest of this document. + +## Using TensorFlow with JDK + +This section describes how to use TensorFlow using the `java` and `javac` +commands from a JDK installation. If your project uses Apache Maven, then +refer to the simpler instructions above instead. + +### Install on Linux or Mac OS Take the following steps to install TensorFlow for Java on Linux or Mac OS: @@ -45,7 +136,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS: file for your operating system and processor support by running the following shell commands: - ```sh + ```bsh TF_TYPE="cpu" # Default processor is CPU. If you want GPU, set to "gpu" OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni @@ -55,7 +146,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS: ``` -## Install on Windows +### Install on Windows Take the following steps to install TensorFlow for Java on Windows: @@ -68,7 +159,7 @@ Take the following steps to install TensorFlow for Java on Windows: -## Validate the installation +### Validate the installation After installing TensorFlow for Java, validate your installation by entering the following code into a file named `HelloTF.java`: @@ -101,30 +192,32 @@ public class HelloTF { } ``` +And use the instructions below to compile and run `HelloTF.java`. + ### Compiling -When compiling a TensorFlow program written in Java, the downloaded `.jar` +When compiling a Java program that uses TensorFlow, the downloaded `.jar` must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -```sh +```bsh javac -cp libtensorflow-1.1.0.jar HelloTF.java ``` ### Running -To execute a TensorFlow program written in Java, ensure that the following -two files are both in your `classpath`: +To execute a Java program that depends on TensorFlow, ensure that the following +two files are available to the JVM: * the downloaded `.jar` file * the extracted JNI library For example, the following command line executes the `HelloTF` program: -```sh +```bsh java -cp libtensorflow-1.1.0.jar:. -Djava.library.path=./jni HelloTF ``` diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index b2a008fb96b..80331e7ea8d 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -114,16 +114,17 @@ Take the following steps to install TensorFlow with Virtualenv: 1. Install pip and virtualenv by issuing one of the following commands:
$ sudo apt-get install python-pip python-dev python-virtualenv # for Python 2.7
-     $ sudo apt-get install python3-pip python3-dev python-virtualenv # for Python 3.n
+ $ sudo apt-get install python3-pip python3-dev python-virtualenv # for Python 3.n 2. Create a virtualenv environment by issuing one of the following commands: -
$ virtualenv --system-site-packages targetDirectory # for Python 2.7 
- $ virtualenv --system-site-packages -p python3 targetDirectory # for Python 3.n +
$ virtualenv --system-site-packages targetDirectory # for Python 2.7
+    $ virtualenv --system-site-packages -p python3 targetDirectory # for Python 3.n
- The targetDirectory specifies the top of the + where targetDirectory specifies the top of the virtualenv tree. Our instructions assume that - targetDirectory is `~/tensorflow`, but you may choose any directory. + targetDirectory is `~/tensorflow`, but you may + choose any directory. 3. Activate the virtualenv environment by issuing one of the following commands: @@ -151,24 +152,24 @@ Take the following steps to install TensorFlow with Virtualenv: lower than 8.1), install TensorFlow in the active virtualenv environment by issuing a command of the following format: -
 (tensorflow)$ pip install --upgrade TF_PYTHON_URL   # Python 2.7
-     (tensorflow)$ pip3 install --upgrade TF_PYTHON_URL  # Python 3.n 
+
 (tensorflow)$ pip install --upgrade tfBinaryURL   # Python 2.7
+     (tensorflow)$ pip3 install --upgrade tfBinaryURL  # Python 3.n 
- where TF_PYTHON_URL identifies the URL of the + where tfBinaryURL identifies the URL of the TensorFlow Python package. The appropriate value of - TF_PYTHON_URLdepends on the operating system, + tfBinaryURLdepends on the operating system, Python version, and GPU support. Find the appropriate value for - TF_PYTHON_URL for your system + tfBinaryURL for your system [here](#the_url_of_the_tensorflow_python_package). For example, if you are installing TensorFlow for Linux, Python 2.7, and CPU-only support, issue the following command to install TensorFlow in the active virtualenv environment:
 (tensorflow)$ pip install --upgrade \\
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.1-cp27-none-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.1-cp27-none-linux_x86_64.whl If you encounter installation problems, see -[Common Installation Problems](#CommonInstallationProblems). +[Common Installation Problems](#common_installation_problems). ### Next Steps @@ -233,7 +234,7 @@ of pip or pip3. If Version 8.1 or later is not installed, issue the following command, which will either install or upgrade to the latest pip version: -
 $ sudo apt-get install python-pip python-dev   # for Python 2.7
+
$ sudo apt-get install python-pip python-dev   # for Python 2.7
 $ sudo apt-get install python3-pip python3-dev # for Python 3.n
 
@@ -256,14 +257,14 @@ take the following steps: 2. (Optional.) If Step 1 failed, install the latest version of TensorFlow by issuing a command of the following format: -
 $ sudo pip  install --upgrade TF_BINARY_URL   # Python 2.7
-     $ sudo pip3 install --upgrade TF_BINARY_URL   # Python 3.n 
+
 $ sudo pip  install --upgrade tfBinaryURL   # Python 2.7
+     $ sudo pip3 install --upgrade tfBinaryURL   # Python 3.n 
- where TF_PYTHON_URL identifies the URL of the + where tfBinaryURL identifies the URL of the TensorFlow Python package. The appropriate value of - TF_BINARY_URL depends on the operating system, + tfBinaryURL depends on the operating system, Python version, and GPU support. Find the appropriate value for - TF_BINARY_URL + tfBinaryURL [here](#the_url_of_the_tensorflow_python_package). For example, to install TensorFlow for Linux, Python 2.7, and CPU-only support, issue the following command: @@ -272,7 +273,7 @@ take the following steps: https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.1-cp27-none-linux_x86_64.whl
If this step fails, see - [Common Installation Problems](#CommonInstallationProblems). + [Common Installation Problems](#common_installation_problems). ### Next Steps @@ -446,9 +447,9 @@ Take the following steps to install TensorFlow in an Anaconda environment: 4. Issue a command of the following format to install TensorFlow inside your conda environment: -
 (tensorflow)$ pip install --ignore-installed --upgrade TF_PYTHON_URL
+
 (tensorflow)$ pip install --ignore-installed --upgrade tfBinaryURL
- where TF_PYTHON_URL is the + where tfBinaryURL is the [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package). For example, the following command installs the CPU-only version of TensorFlow for Python 2.7: diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index c2ce7d8bb36..592036d1eb3 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -119,15 +119,15 @@ Take the following steps to install TensorFlow with Virtualenv: $ virtualenv --system-site-packages -p python3 targetDirectory # for Python 3.n - The targetDirectory identifies the top of the virtualenv tree. + where targetDirectory identifies the top of the virtualenv tree. Our instructions assume that targetDirectory is `~/tensorflow`, but you may choose any directory. 4. Activate the virtualenv environment by issuing one of the following commands: -
 $ source ~/tensorflow/bin/activate      # If using bash, sh, ksh, or zsh
-     $ source ~/tensorflow/bin/activate.csh  # If using csh or tcsh 
+
$ source ~/tensorflow/bin/activate      # If using bash, sh, ksh, or zsh
+    $ source ~/tensorflow/bin/activate.csh  # If using csh or tcsh 
The preceding `source` command should change your prompt to the following: @@ -149,14 +149,14 @@ Take the following steps to install TensorFlow with Virtualenv: lower than 8.1), install TensorFlow in the active virtualenv environment by issuing a command of the following format: -
 $ pip install --upgrade TF_BINARY_URL   # Python 2.7
-     $ pip3 install --upgrade TF_BINARY_URL  # Python 3.n 
+
 $ pip install --upgrade tfBinaryURL   # Python 2.7
+     $ pip3 install --upgrade tfBinaryURL  # Python 3.n 
- where TF_BINARY_URL identifies the URL + where tfBinaryURL identifies the URL of the TensorFlow Python package. The appropriate value of - TF_BINARY_URL depends on the operating system, + tfBinaryURL depends on the operating system, Python version, and GPU support. Find the appropriate value for - TF_BINARY_URL for your system + tfBinaryURL for your system [here](#the_url_of_the_tensorflow_python_package). For example, if you are installing TensorFlow for Mac OS X, Python 2.7, and CPU-only support, the command to install @@ -180,7 +180,7 @@ use TensorFlow in a new shell. If the virtualenv environment is not currently active (that is, the prompt is not `(tensorflow)`, invoke one of the following commands: -
 $ source ~/tensorflow/bin/activate      # bash, sh, ksh, or zsh
+
$ source ~/tensorflow/bin/activate      # bash, sh, ksh, or zsh
 $ source ~/tensorflow/bin/activate.csh  # csh or tcsh 
Your prompt will transform to the following to indicate that your @@ -263,7 +263,7 @@ take the following steps: 1. Install TensorFlow by invoking **one** of the following commands: -
$ pip install tensorflow      # Python 2.7; CPU support (no GPU support)
+     
 $ pip install tensorflow      # Python 2.7; CPU support (no GPU support)
      $ pip3 install tensorflow     # Python 3.n; CPU support (no GPU support)
      $ pip install tensorflow-gpu  # Python 2.7;  GPU support
      $ pip3 install tensorflow-gpu # Python 3.n; GPU support 
@@ -274,13 +274,13 @@ take the following steps: 2. (Optional.) If Step 1 failed, install the latest version of TensorFlow by issuing a command of the following format: -
$ sudo pip  install --upgrade TF_BINARY_URL   # Python 2.7
-     $ sudo pip3 install --upgrade TF_BINARY_URL   # Python 3.n 
+
 $ sudo pip  install --upgrade tfBinaryURL   # Python 2.7
+     $ sudo pip3 install --upgrade tfBinaryURL   # Python 3.n 
- where TF_BINARY_URL identifies the URL of the TensorFlow Python - package. The appropriate value of TF_BINARY_URL depends on the + where tfBinaryURL identifies the URL of the TensorFlow Python + package. The appropriate value of tfBinaryURL depends on the operating system, Python version, and GPU support. Find the appropriate - value for TF_BINARY_URL + value for tfBinaryURL [here](#the_url_of_the_tensorflow_python_package). For example, if you are installing TensorFlow for Mac OS, Python 2.7, and CPU-only support, issue the following command: @@ -390,9 +390,9 @@ Take the following steps to install TensorFlow in an Anaconda environment: 4. Issue a command of the following format to install TensorFlow inside your conda environment: -
(tensorflow)$ pip install --ignore-installed --upgrade $TF_PYTHON_URL
+
(tensorflow)$ pip install --ignore-installed --upgrade TF_PYTHON_URL
- where `TF_PYTHON_URL` is the + where TF_PYTHON_URL is the [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package). For example, the following command installs the CPU-only version of TensorFlow for Python 2.7: @@ -658,13 +658,13 @@ the custom binary protobuf pip package, invoke one of the following commands: * for Python 2.7: -
 $ pip install --upgrade \
-  https://storage.googleapis.com/tensorflow/linux/cpu/protobuf-3.1.0-cp27-none-linux_x86_64.whl
+
$ pip install --upgrade \
+    https://storage.googleapis.com/tensorflow/linux/cpu/protobuf-3.1.0-cp27-none-linux_x86_64.whl
* for Python 3.n: -
 $ pip3 install --upgrade \
-  https://storage.googleapis.com/tensorflow/linux/cpu/protobuf-3.1.0-cp35-none-linux_x86_64.whl 
+
$ pip3 install --upgrade \
+    https://storage.googleapis.com/tensorflow/linux/cpu/protobuf-3.1.0-cp35-none-linux_x86_64.whl 
Installing this protobuf package will overwrite the existing protobuf package. Note that the binary pip package already has support for protobufs diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index d9842822182..7ecddc548fe 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -407,6 +407,37 @@ python -m tensorflow.python.debug.examples.debug_errors \ --error uninitialized_variable --debug ``` +**Q**: _The model I am debugging is very large. The data dumped by tfdbg +fills up the free space of my disk. What can I do?_ + +**A**: +For large models, i.e., models with many intermediate tensors, large sizes in +individual intermediate tensors and/or many iterations in any `tf.while_loop`s +that the graph contains, this kind of disk space issue can happen. + +There are three possible workarounds or solutions: + +1. The constructors of `LocalCLIDebugWrapperSession` and `LocalCLIDebugHook` + provide a keyword argument, `dump_root`, with which you can specify the path + to which **tfdbg** dumps the debug data. For example: + ``` python + # For LocalCLIDebugWrapperSession + sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space") + + # For LocalCLIDebugHook + hooks = [tf_debug.LocalCLIDebugHook(dump_root="/with/lots/of/space")] + ``` + Make sure that the directory pointed to by dump_root is empty or nonexistent. + **tfdbg** cleans up the dump directories before exiting. +2. Reduce the batch size used during the runs. +3. Use the filtering options of **tfdbg**'s `run` command to watch only specific + nodes in the graph. For example: + ``` + tfdbg> run --node_name_filter .*hidden.* + tfdbg> run --op_type_filter Variable.* + tfdbg> run --tensor_dtype_filter int.* + ``` + **Q**: _Why can't I select text in the tfdbg CLI?_ **A**: This is because the tfdbg CLI enables mouse events in the terminal by diff --git a/tensorflow/docs_src/tutorials/mandelbrot.md b/tensorflow/docs_src/tutorials/mandelbrot.md index a43ac7e8cc9..7d8abbdcba6 100755 --- a/tensorflow/docs_src/tutorials/mandelbrot.md +++ b/tensorflow/docs_src/tutorials/mandelbrot.md @@ -7,7 +7,6 @@ actually a pretty naive implementation of the visualization, but it makes the point. (We may end up providing a more elaborate implementation down the line to produce more truly beautiful images.) -Note: This tutorial was originally prepared as an IPython notebook. ## Basic Setup diff --git a/tensorflow/docs_src/tutorials/pdes.md b/tensorflow/docs_src/tutorials/pdes.md index c1ce77a73ac..ec6915074ba 100755 --- a/tensorflow/docs_src/tutorials/pdes.md +++ b/tensorflow/docs_src/tutorials/pdes.md @@ -6,7 +6,6 @@ pedestrian) example of using TensorFlow for simulating the behavior of a https://en.wikipedia.org/wiki/Partial_differential_equation). We'll simulate the surface of square pond as a few raindrops land on it. -Note: This tutorial was originally prepared as an IPython notebook. ## Basic Setup diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d2578527ae3..de014abafa9 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -1710,58 +1710,6 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t return output } -// UniqueAttr is an optional argument to Unique. -type UniqueAttr func(optionalAttr) - -// UniqueOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func UniqueOutIdx(value tf.DataType) UniqueAttr { - return func(m optionalAttr) { - m["out_idx"] = value - } -} - -// Finds unique elements in a 1-D tensor. -// -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. In other words: -// -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` -// -// For example: -// -// ```prettyprint -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx = unique(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// ``` -// -// Arguments: -// x: 1-D. -// -// Returns 1-D.1-D. -func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Unique", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - // Reshapes a tensor. // // Given `tensor`, this operation returns a tensor that has the same values @@ -2506,6 +2454,121 @@ func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf return op.Output(0) } +// UniqueAttr is an optional argument to Unique. +type UniqueAttr func(optionalAttr) + +// UniqueOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueOutIdx(value tf.DataType) UniqueAttr { + return func(m optionalAttr) { + m["out_idx"] = value + } +} + +// Finds unique elements in a 1-D tensor. +// +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. In other words: +// +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` +// +// For example: +// +// ```prettyprint +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx = unique(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +// ``` +// +// Arguments: +// x: 1-D. +// +// Returns 1-D.1-D. +func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Unique", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_channels"] = value + } +} + +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_samples"] = value + } +} + +// Decode a 16-bit PCM WAV file to a float tensor. +// +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. +// +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. +// +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. +// +// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. +// +// Arguments: +// contents: The WAV-encoded audio, usually from a file. +// +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeWav", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. type AllCandidateSamplerAttr func(optionalAttr) @@ -8606,6 +8669,35 @@ func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { return op.Output(0) } +// Encode audio data using the WAV file format. +// +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. +// +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). +// +// Arguments: +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. +// +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "EncodeWav", + Input: []tf.Input{ + audio, sample_rate, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // The gradient operator for the SparseAdd op. // // The SparseAdd op calculates A + B, where A, B, and the sum are all represented diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index 20eb6a8265c..e31f57dc25d 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -2,6 +2,8 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)) +[![Maven Central](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow/badge.svg)](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow) + > *WARNING*: The TensorFlow Java API is not currently covered by the TensorFlow > [API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). > @@ -11,32 +13,95 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/jav > and/or the [Android > demo](https://www.tensorflow.org/code/tensorflow/examples/android). -## Quickstart +## Quickstart: Using [Apache Maven](https://maven.apache.org) + +TensorFlow for Java releases are included in +[Maven Central](https://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.tensorflow%22%20AND%20a%3A%22tensorflow%22) +and support Linux, OS X and Windows. To use it, add the following dependency to +your project's `pom.xml`: + +```xml + + org.tensorflow + tensorflow + 1.1.0-rc0-windows-fix + +``` + +That's all. As an example, to create a Maven project for the +[label image example](https://www.tensorflow.org/code/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java): + +1. Create a `pom.xml`: + + ```xml + + 4.0.0 + org.myorg + label-image + 1.0-SNAPSHOT + + org.tensorflow.examples.LabelImage + + + 1.7 + 1.7 + + + + org.tensorflow + tensorflow + 1.1.0-rc0-windows-fix + + + + ``` + +2. Download the [example source](https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java) + into `src/main/java/org/tensorflow/examples`. On Linux and OS X, the following script should work: + + ```sh + mkdir -p src/main/java/org/tensorflow/examples + curl -L "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java" -o src/main/java/org/tensorflow/examples/LabelImage.java + ``` + +3. Compile and execute: + + ```sh + mvn compile exec:java + ``` + +## Quickstart: Using `java` and `javac` + +This section describes how to use TensorFlow armed with just a JDK installation. 1. Download the Java archive (JAR): - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.0.0-PREVIEW1.jar) + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc0.jar) (optionally, the Java sources: - [libtensorflow-src.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-src-1.0.0-PREVIEW1.jar)). + [libtensorflow-src.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-src-1.1.0-rc0.jar)). 2. Download the native library. GPU-enabled versions required CUDA 8 and cuDNN 5.1. For other versions, the native library will need to be built from source (see below). - Linux: - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.0.0-PREVIEW1.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.0.0-PREVIEW1.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.1.0-rc0.tar.gz), + [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.1.0-rc0.tar.gz) - OS X: - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.0.0-PREVIEW1.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-darwin-x86_64-1.0.0-PREVIEW1.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.1.0-rc0.tar.gz), + [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-darwin-x86_64-1.1.0-rc0.tar.gz) + - Windows: + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc0.zip) - The following shell snippet downloads and extracts the native library: + + The following shell snippet downloads and extracts the native library on + Linux and OS X. For Windows, download and extract manually. ```sh TF_TYPE="cpu" # Set to "gpu" to enable GPU support OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.0.0-PREVIEW1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc0.tar.gz" | tar -xz -C ./jni ``` @@ -56,7 +121,7 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/jav then it should be compiled with: ```sh - javac -cp libtensorflow-1.0.0-PREVIEW1.jar MyClass.java + javac -cp libtensorflow-1.1.0-rc0.jar MyClass.java ``` For a more sophisticated example, see @@ -65,7 +130,7 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/jav ```sh javac \ - -cp libtensorflow-1.0.0-PREVIEW1.jar \ + -cp libtensorflow-1.1.0-rc0.jar \ ./src/main/java/org/tensorflow/examples/LabelImage.java ``` @@ -73,7 +138,7 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/jav library path during execution. For example: ```sh - java -cp libtensorflow-1.0.0-PREVIEW1.jar:. -Djava.library.path=./jni MyClass + java -cp libtensorflow-1.1.0-rc0.jar:. -Djava.library.path=./jni MyClass ``` or for the `LabelImage` example: @@ -81,17 +146,14 @@ Java bindings for TensorFlow. ([Javadoc](https://www.tensorflow.org/api_docs/jav ```sh java \ -Djava.library.path=./jni \ - -cp libtensorflow-1.0.0-PREVIEW1.jar:./src/main/java \ + -cp libtensorflow-1.1.0-rc0.jar:./src/main/java \ org.tensorflow.examples.LabelImage ``` -That's all. These artifacts are not yet available on Maven central, see -[#6926](https://github.com/tensorflow/tensorflow/issues/6926). - ## Building from source -If the quickstart instructions above do not work out, the TensorFlow native -libraries will need to be built from source. +If the quickstart instructions above do not work out, the TensorFlow Java and +native libraries will need to be built from source. 1. Install [bazel](https://www.bazel.build/versions/master/docs/install.html) @@ -110,6 +172,7 @@ libraries will need to be built from source. brew install swig ``` + 3. [Configure](https://www.tensorflow.org/install/install_sources#configure_the_installation) (e.g., enable GPU support) and build: @@ -120,14 +183,22 @@ libraries will need to be built from source. //tensorflow/java:libtensorflow_jni ``` -The JAR (`libtensorflow.jar`) and native library (`libtensorflow_jni.so` on Linux or `libtensorflow_jni.dylib` on OS X) will -be in `bazel-bin/tensorflow/java`. Using these artifacts follow both steps 3 and 4 in the [quickstart](#quickstart) section in order to get your application up and running. +The JAR (`libtensorflow.jar`) and native library (`libtensorflow_jni.so` on +Linux, `libtensorflow_jni.dylib` on OS X, `tensorflow_jni.dll` on Windows) will +be in `bazel-bin/tensorflow/java`. Using these artifacts follow both steps 3 +and 4 in the previous section in order to get your application +up and running. + +Installation on Windows requires the more experimental [bazel on Windows](https://bazel.build/versions/master/docs/windows.html). +Details are elided here, but find inspiration in the script used for +building the release archive: +[`tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh`](https://www.tensorflow.org/code/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh). ### Maven -To use the library in an external Java project, publish the library to a Maven -repository. For example, publish the library to the local Maven repository using -the `mvn` tool (installed separately): +Details of the release process for Maven Central are in [`maven/README.md`](https://www.tensorflow.org/code/tensorflow/java/maven/README.md). +However, for development, you can push the library built from source to a local +Maven repository with: ```sh bazel build -c opt //tensorflow/java:pom @@ -136,9 +207,8 @@ mvn install:install-file \ -DpomFile=../../bazel-bin/tensorflow/java/pom.xml ``` -Refer to the library using Maven coordinates. For example, if you're using Maven -then place this dependency into your `pom.xml` file (replacing 1.0.head with -the version of the TensorFlow runtime you wish to use). +And then rever to this library in a project's `pom.xml` with: +(replacing 1.0.head with the appropriate version): ```xml diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 038dc4147ab..6900ac9a4f4 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1091,17 +1091,15 @@ class BaseSession(SessionInterface): tensors_to_delete = self._dead_handles self._dead_handles = [] # Delete the dead tensors. - # TODO(yuanbyu): For now we use a sequence of runs to minimize the graph - # size and the overhead of graph construction/partitioning. if tensors_to_delete: + feeds = {} + fetches = [] for tensor_handle in tensors_to_delete: - feeds = {} - fetches = [] holder, deleter = session_ops._get_handle_deleter(self.graph, tensor_handle) feeds[holder] = tensor_handle fetches.append(deleter) - self.run(fetches, feed_dict=feeds) + self.run(fetches, feed_dict=feeds) def _update_with_movers(self, feed_dict, feed_map): # If a tensor handle that is fed to a device incompatible placeholder, diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index a792e0d2f64..b487671e90a 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -661,26 +661,17 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): Args: sess: The TensorFlow `Session` object being wrapped. - watch_fn: (`Callable`) A Callable of the following signature: - ``` - def watch_fn(fetches, feeds): - # Args: - # fetches: the fetches to the `Session.run()` call. - # feeds: the feeds to the `Session.run()` call. - # - # Returns: (node_name_regex_whitelist, op_type_regex_whitelist) - # debug_ops: (str or list of str) Debug op(s) to be used by the - # debugger in this run() call. - # node_name_regex_whitelist: Regular-expression whitelist for node - # name. Same as the corresponding arg to `debug_util.watch_graph`. - # op_type_regex_whiteslit: Regular-expression whitelist for op type. - # Same as the corresponding arg to `debug_util.watch_graph`. - # - # Both or either can be None. If both are set, the two whitelists - # will operate in a logical AND relation. This is consistent with - # `debug_utils.watch_graph()`. - ``` + watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a + debugged `Session.run()` call to `WatchOptions.` + * Args: + * `fetches`: the fetches to the `Session.run()` call. + * `feeds`: the feeds to the `Session.run()` call. + * Returns: + (`tf_debug.WatchOptions`) An object containing debug options including + the debug ops to use, the node names, op types and/or tensor data + types to watch, etc. See the documentation of `tf_debug.WatchOptions` + for more details. Raises: TypeError: If a non-None `watch_fn` is specified and it is not callable. """ diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 150a7b081dc..c86cc413c11 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -110,8 +110,18 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook, run_args = session_run_hook.SessionRunArgs( None, feed_dict=None, options=config_pb2.RunOptions()) if self._performed_action == framework.OnRunStartAction.DEBUG_RUN: - self._decorate_options_for_debug(run_args.options, - run_context.session.graph) + self._decorate_options_for_debug( + run_args.options, + run_context.session.graph, + framework.WatchOptions( + node_name_regex_whitelist=( + on_run_start_response.node_name_regex_whitelist), + op_type_regex_whitelist=( + on_run_start_response.op_type_regex_whitelist), + tensor_dtype_regex_whitelist=( + on_run_start_response.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + on_run_start_response.tolerate_debug_op_creation_failures))) elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER: # The _finalized property must be set to False so that the NodeStepper # can insert ops for retrieving TensorHandles. @@ -136,16 +146,17 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook, run_values.run_metadata) self.on_run_end(on_run_end_request) - def _decorate_options_for_debug(self, options, graph): - """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging. - - Args: - options: (config_pb2.RunOptions) The RunOptions instance to be modified. - graph: A TensorFlow Graph object. - """ - + def _decorate_options_for_debug(self, options, graph, watch_options): + """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.""" debug_utils.watch_graph( - options, graph, debug_urls=self._get_run_debug_urls()) + options, + graph, + debug_urls=self._get_run_debug_urls(), + node_name_regex_whitelist=watch_options.node_name_regex_whitelist, + op_type_regex_whitelist=watch_options.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist, + tolerate_debug_op_creation_failures=( + watch_options.tolerate_debug_op_creation_failures)) options.output_partition_graphs = True diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 1aab95152ad..b29259c901d 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -44,7 +44,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): will launch the command-line interface (CLI) of tfdbg. """ - def __init__(self, sess, dump_root=None, log_usage=True, ui_type="curses"): + def __init__(self, + sess, + dump_root=None, + log_usage=True, + ui_type="curses"): """Constructor of LocalCLIDebugWrapperSession. Args: @@ -133,6 +137,27 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): type=str, default="", help="Run until a tensor in the graph passes the specified filter.") + ap.add_argument( + "--node_name_filter", + dest="node_name_filter", + type=str, + default="", + help="Regular-expression filter for node names to be watched in the " + "run, e.g., loss, reshape.*") + ap.add_argument( + "--op_type_filter", + dest="op_type_filter", + type=str, + default="", + help="Regular-expression filter for op type to be watched in the run, " + "e.g., (MatMul|Add), Variable.*") + ap.add_argument( + "--tensor_dtype_filter", + dest="tensor_dtype_filter", + type=str, + default="", + help="Regular-expression filter for tensor dtype to be watched in the " + "run, e.g., (float32|float64), int.*") self._argparsers["run"] = ap ap = argparse.ArgumentParser( @@ -176,15 +201,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): `run` / `invoke_stepper`. Args: - request: An instance of `OnSessionInitRequest`. + request: An instance of `OnRunStartRequest`. Returns: - An instance of `OnSessionInitResponse`. - - Raises: - RuntimeError: If user chooses to prematurely exit the debugger. + An instance of `OnRunStartResponse`. """ - self._is_run_start = True self._update_run_calls_state(request.run_call_count, request.fetches, request.feed_dict) @@ -195,6 +216,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): return framework.OnRunStartResponse(framework.OnRunStartAction.DEBUG_RUN, self._get_run_debug_urls()) + self._exit_if_requested_by_user() + if self._run_call_count > 1 and not self._skip_debug: if self._run_through_times > 0: # Just run through without debugging. @@ -203,9 +226,10 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): elif self._run_through_times == 0: # It is the run at which the run-end CLI will be launched: activate # debugging. - return framework.OnRunStartResponse( - framework.OnRunStartAction.DEBUG_RUN, - self._get_run_debug_urls()) + return (self._run_start_response or + framework.OnRunStartResponse( + framework.OnRunStartAction.DEBUG_RUN, + self._get_run_debug_urls())) if self._run_start_response is None: self._prep_cli_for_run_start() @@ -214,6 +238,10 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_through_times > 1: self._run_through_times -= 1 + self._exit_if_requested_by_user() + return self._run_start_response + + def _exit_if_requested_by_user(self): if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT: # Explicit user "exit" command leads to sys.exit(1). print( @@ -221,8 +249,6 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): file=sys.stderr) sys.exit(1) - return self._run_start_response - def _prep_cli_for_run_start(self): """Prepare (but not launch) the CLI for run-start.""" @@ -398,6 +424,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): _ = screen_info # Currently unused. parsed = self._argparsers["run"].parse_args(args) + parsed.node_name_filter = parsed.node_name_filter or None + parsed.op_type_filter = parsed.op_type_filter or None + parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None if parsed.till_filter_pass: # For the run-till-bad-numerical-value-appears mode, use the DEBUG_RUN @@ -425,7 +454,12 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # Raise CommandLineExit exception to cause the CLI to exit. raise debugger_cli_common.CommandLineExit( - exit_token=framework.OnRunStartResponse(action, debug_urls)) + exit_token=framework.OnRunStartResponse( + action, + debug_urls, + node_name_regex_whitelist=parsed.node_name_filter, + op_type_regex_whitelist=parsed.op_type_filter, + tensor_dtype_regex_whitelist=parsed.tensor_dtype_filter)) def _register_this_run_info(self, curses_cli): curses_cli.register_command_handler( diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 01578dcb2ad..e22f6e783e8 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -45,7 +46,10 @@ class LocalCLIDebuggerWrapperSessionForTest( Inserts observer variables for assertions. """ - def __init__(self, command_args_sequence, sess, dump_root=None): + def __init__(self, + command_args_sequence, + sess, + dump_root=None): """Constructor of the for-test subclass. Args: @@ -99,9 +103,15 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self._tmp_dir = tempfile.mktemp() self.v = variables.Variable(10.0, name="v") + self.w = variables.Variable(21.0, name="w") self.delta = constant_op.constant(1.0, name="delta") self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v") + self.w_int = control_flow_ops.with_dependencies( + [self.inc_v], + math_ops.cast(self.w, dtypes.int32, name="w_int_inner"), + name="w_int_outer") + self.ph = array_ops.placeholder(dtypes.float32, name="ph") self.xph = array_ops.transpose(self.ph, name="xph") self.m = constant_op.constant( @@ -111,7 +121,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.sess = session.Session() # Initialize variable. - self.sess.run(self.v.initializer) + self.sess.run(variables.global_variables_initializer()) def tearDown(self): ops.reset_default_graph() @@ -356,6 +366,108 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"])) self.assertEqual([None, None], wrapped_sess.observers["tf_errors"]) + def testRunsUnderDebugModeWithWatchFnFilteringNodeNames(self): + # Test command sequence: + # run --node_name_filter inc.* + # run --node_name_filter delta + # run + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["--node_name_filter", "inc.*"], ["--node_name_filter", "delta"], []], + self.sess, dump_root=self._tmp_dir) + + # run under debug mode twice. + wrapped_sess.run(self.inc_v) + wrapped_sess.run(self.inc_v) + + # Verify that the assign_add op did take effect. + self.assertAllClose(12.0, self.sess.run(self.v)) + + # Verify that the dumps have been generated and picked up during run-end. + self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"])) + + dumps = wrapped_sess.observers["debug_dumps"][0] + self.assertEqual(1, dumps.size) + self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name) + + dumps = wrapped_sess.observers["debug_dumps"][1] + self.assertEqual(1, dumps.size) + self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name) + + def testRunsUnderDebugModeWithWatchFnFilteringOpTypes(self): + # Test command sequence: + # run --node_name_filter delta + # run --op_type_filter AssignAdd + # run + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["--node_name_filter", "delta"], + ["--op_type_filter", "AssignAdd"], + []], + self.sess, dump_root=self._tmp_dir) + + # run under debug mode twice. + wrapped_sess.run(self.inc_v) + wrapped_sess.run(self.inc_v) + + # Verify that the assign_add op did take effect. + self.assertAllClose(12.0, self.sess.run(self.v)) + + # Verify that the dumps have been generated and picked up during run-end. + self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"])) + + dumps = wrapped_sess.observers["debug_dumps"][0] + self.assertEqual(1, dumps.size) + self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name) + + dumps = wrapped_sess.observers["debug_dumps"][1] + self.assertEqual(1, dumps.size) + self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name) + + def testRunsUnderDebugModeWithWatchFnFilteringTensorDTypes(self): + # Test command sequence: + # run --op_type_filter Variable.* + # run --dtype_filter int32 + # run + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["--op_type_filter", "Variable.*"], + ["--tensor_dtype_filter", "int32"], []], + self.sess, dump_root=self._tmp_dir) + + # run under debug mode twice. + wrapped_sess.run(self.w_int) + wrapped_sess.run(self.w_int) + + # Verify that the dumps have been generated and picked up during run-end. + self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"])) + + dumps = wrapped_sess.observers["debug_dumps"][0] + self.assertEqual(2, dumps.size) + self.assertItemsEqual( + ["v", "w"], [dumps.dumped_tensor_data[i].node_name for i in [0, 1]]) + + dumps = wrapped_sess.observers["debug_dumps"][1] + self.assertEqual(2, dumps.size) + self.assertEqual( + ["w_int_inner", "w_int_outer"], + [dumps.dumped_tensor_data[i].node_name for i in [0, 1]]) + + def testRunsUnderDebugModeWithWatchFnFilteringOpTypesAndTensorDTypes(self): + # Test command sequence: + # run --op_type_filter Cast --dtype_filter int32 + # run + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["--op_type_filter", "Cast", "--tensor_dtype_filter", "int32"], []], + self.sess, dump_root=self._tmp_dir) + + # run under debug mode twice. + wrapped_sess.run(self.w_int) + + # Verify that the dumps have been generated and picked up during run-end. + self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"])) + + dumps = wrapped_sess.observers["debug_dumps"][0] + self.assertEqual(1, dumps.size) + self.assertEqual("w_int_inner", dumps.dumped_tensor_data[0].node_name) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index a2c610f51fe..616b7ae49b1 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -75,6 +75,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:training", + "//tensorflow/python:util", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index b018f999ea9..cac0b55bd38 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -450,6 +450,27 @@ class EstimatorEvaluateTest(test.TestCase): self.assertIn('metric', scores) self.assertAlmostEqual(2., scores['metric']) + def test_tuple_metrics(self): + def _model_fn(features, labels, mode): + del features # unused + del labels + return model_fn_lib.EstimatorSpec( + mode, + train_op=control_flow_ops.no_op(), + loss=constant_op.constant(1.), + eval_metric_ops={ + 'nested_metric': ( + ((constant_op.constant(2.), constant_op.constant(1)), + constant_op.constant(3., dtype=dtypes.float64)), + control_flow_ops.no_op())}) + est = estimator.Estimator(model_fn=_model_fn) + est.train(dummy_input_fn, steps=1) + evaluation = est.evaluate(dummy_input_fn, steps=1) + ((two_float, one_integer), three_double) = evaluation['nested_metric'] + self.assertAlmostEqual(2., two_float) + self.assertEqual(1, one_integer) + self.assertAlmostEqual(3., three_double) + def test_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 26e8f514252..ee5999c78bc 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook +from tensorflow.python.util import nest class ModeKeys(object): @@ -195,15 +196,20 @@ class EstimatorSpec( if not isinstance(eval_metric_ops, dict): raise TypeError( 'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops)) - for key, metric_value in six.iteritems(eval_metric_ops): - if (not isinstance(metric_value, tuple) or - len(metric_value) != 2): + for key, metric_value_and_update in six.iteritems(eval_metric_ops): + if (not isinstance(metric_value_and_update, tuple) or + len(metric_value_and_update) != 2): raise TypeError( - 'Values of eval_metric_ops must be (metric_tensor, update_op) ' - 'tuples, given: {} for key: {}'.format(metric_value, key)) - _check_is_tensor_or_operation(metric_value[0], - 'eval_metric_ops[{}]'.format(key)) - _check_is_tensor_or_operation(metric_value[1], + 'Values of eval_metric_ops must be (metric_value, update_op) ' + 'tuples, given: {} for key: {}'.format( + metric_value_and_update, key)) + metric_value, metric_update = metric_value_and_update + for metric_value_member in nest.flatten(metric_value): + # Allow (possibly nested) tuples for metric values, but require that + # each of them be Tensors or Operations. + _check_is_tensor_or_operation(metric_value_member, + 'eval_metric_ops[{}]'.format(key)) + _check_is_tensor_or_operation(metric_update, 'eval_metric_ops[{}]'.format(key)) # Validate export_outputs. @@ -239,7 +245,7 @@ class EstimatorSpec( raise ValueError('loss must be from the default graph.') if train_op is not None and train_op.graph is not default_graph: raise ValueError('train_op must be from the default graph.') - for value in _eval_metric_ops_values(eval_metric_ops): + for value in nest.flatten(list(eval_metric_ops.values())): if value.graph is not default_graph: raise ValueError( 'eval_metric_ops values must be from the default graph.') @@ -292,14 +298,3 @@ def _prediction_values(predictions): if isinstance(predictions, dict): return list(six.itervalues(predictions)) return [predictions] - - -def _eval_metric_ops_values(eval_metric_ops): - """Returns the values of the given eval_metric_ops dict.""" - if eval_metric_ops is None: - return [] - result = [] - for value_tuple in six.itervalues(eval_metric_ops): - result.append(value_tuple[0]) - result.append(value_tuple[1]) - return result diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index 8ae6557f1b1..96c38a987b3 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -223,6 +223,17 @@ class EstimatorSpecEvalTest(test.TestCase): training_hooks=[_FakeHook()], scaffold=monitored_session.Scaffold()) + def testTupleMetric(self): + """Tests that no errors are raised when a metric is tuple-valued.""" + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, + loss=loss, + eval_metric_ops={ + 'some_metric': ((loss, loss, (constant_op.constant(2), loss)), + control_flow_ops.no_op())}) + def testLoss1DTensor(self): """Tests that no errors are raised when loss is 1D tensor.""" with ops.Graph().as_default(), self.test_session(): @@ -345,7 +356,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss = constant_op.constant(1.) with self.assertRaisesRegexp( TypeError, - (r'Values of eval_metric_ops must be \(metric_tensor, update_op\) ' + (r'Values of eval_metric_ops must be \(metric_value, update_op\) ' 'tuples')): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -363,6 +374,17 @@ class EstimatorSpecEvalTest(test.TestCase): loss=loss, eval_metric_ops={'loss': ('NonTensor', loss)}) + def testEvalMetricNestedNoTensorOrOperation(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'): + model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, + predictions={'loss': loss}, + loss=loss, + eval_metric_ops={'loss': ((('NonTensor',),), + control_flow_ops.no_op())}) + def testEvalMetricOpsFromDifferentGraph(self): with ops.Graph().as_default(): eval_metric_ops = { diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index ffc5f19c25c..04c43ef5fa4 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -330,8 +330,11 @@ class Conv3DTest(test.TestCase): if test.is_gpu_available() and use_gpu: data_type = dtypes.float32 + # TOOD(mjanusz): Modify gradient_checker to also provide max relative + # error and synchronize the tolerance levels between the tests for forward + # and backward computations. if test.is_gpu_available(): - tolerance = 4e-3 + tolerance = 5e-3 else: # As of Aug 2016, higher tolerance is needed for some CPU architectures. # Runs on a single machine can also generate slightly different errors diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 938e451cbcd..f1bb3bdc228 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -160,6 +160,14 @@ class PyOpTest(test.TestCase): _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) self.assertTrue(script_ops._py_funcs.size() < 100) + def testAlias(self): + with self.test_session(): + np_array = np.array([1.0, 2.0], dtype=np.float32) + tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32]) + value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32) + value.op.run() + self.assertAllEqual(np_array, [1.0, 2.0]) + def testBadNumpyReturnType(self): with self.test_session(): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 3c67284f7f1..9b76585f9fb 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -23,6 +23,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import functools import inspect import re @@ -86,10 +87,14 @@ class _Layer(object): self._updates = [] self._losses = [] self._reuse = kwargs.get('_reuse') + self._graph = ops.get_default_graph() self.dtype = dtype # Determine base name (non-unique). - base_name = name + if isinstance(name, vs.VariableScope): + base_name = name.name + else: + base_name = name if not name: base_name = _to_snake_case(self.__class__.__name__) self._base_name = base_name @@ -306,6 +311,13 @@ class _Layer(object): with vs.variable_scope(self._scope, reuse=True if self._built else self._reuse, custom_getter=variable_getter) as scope: + # Ensure the Layer, if being reused, is working with inputs from + # the same graph as where it was created. + try: + ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph) # pylint: disable=protected-access + except ValueError as e: + raise ValueError("Inputs' and Layer's graphs are not the same: %s" % e) + with ops.name_scope(scope.original_name_scope): if not self.built: input_list = [ @@ -335,6 +347,25 @@ class _Layer(object): _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) return outputs + @property + def graph(self): + return self._graph + + def __deepcopy__(self, memo): + no_copy = set(['_graph']) + shallow_copy = set(['_scope']) + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k in no_copy: + setattr(result, k, v) + elif k in shallow_copy: + setattr(result, k, copy.copy(v)) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + def apply(self, inputs, **kwargs): """Apply the layer on a input. diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index c82ec47671a..83ae1b6e835 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy + from tensorflow.python.framework import ops from tensorflow.python.layers import base as base_layers from tensorflow.python.ops import init_ops @@ -145,6 +147,12 @@ class BaseLayerTest(test.TestCase): self.assertListEqual(lazy_layer.variables, []) self.assertEqual(lazy_layer.name, 'new_scope') + with ops.Graph().as_default(): + inputs_ng = random_ops.random_uniform((5,), seed=1) + with self.assertRaisesRegexp(ValueError, + r'graphs are not the same'): + layer.apply(inputs_ng) + def testCall(self): class MyLayer(base_layers._Layer): @@ -158,6 +166,24 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.built, True) self.assertEqual(outputs.op.name, 'my_layer/Square') + def testDeepCopy(self): + + class MyLayer(base_layers._Layer): + + def call(self, inputs): + return math_ops.square(inputs) + + layer = MyLayer(name='my_layer') + inputs = random_ops.random_uniform((5,), seed=1) + outputs = layer.apply(inputs) + self.assertEqual(layer.built, True) + self.assertEqual(outputs.op.name, 'my_layer/Square') + + layer_copy = copy.deepcopy(layer) + self.assertEqual(layer_copy.name, layer.name) + self.assertEqual(layer_copy._scope.name, layer._scope.name) + self.assertEqual(layer_copy._graph, layer._graph) + def testNaming(self): class PrivateLayer(base_layers._Layer): @@ -178,11 +204,6 @@ class BaseLayerTest(test.TestCase): my_layer1 = PrivateLayer(name='my_layer') my_layer1.apply(inputs) self.assertEqual(my_layer1.name, 'my_layer_1') - # New graph has fully orthogonal names. - with ops.Graph().as_default(): - my_layer_other_graph = PrivateLayer(name='my_layer') - my_layer_other_graph.apply(inputs) - self.assertEqual(my_layer_other_graph.name, 'my_layer') my_layer2 = PrivateLayer(name='my_layer') my_layer2.apply(inputs) self.assertEqual(my_layer2.name, 'my_layer_2') diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index ecdabaf04a8..3b8959e2106 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -124,10 +124,10 @@ class _Conv(base._Layer): # pylint: disable=protected-access channel_axis = 1 else: channel_axis = -1 - if input_shape[channel_axis] is None: + if input_shape[channel_axis].value is None: raise ValueError('The channel dimension of the inputs ' 'should be defined. Found `None`.') - input_dim = input_shape[channel_axis] + input_dim = input_shape[channel_axis].value kernel_shape = self.kernel_size + (input_dim, self.filters) self.kernel = vs.get_variable('kernel', diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py index 1a5fe5c9b7d..c3e133d08b2 100644 --- a/tensorflow/python/layers/convolutional_test.py +++ b/tensorflow/python/layers/convolutional_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import convolutional as conv_layers from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -88,6 +89,23 @@ class ConvTest(test.TestCase): self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 4, 32]) self.assertListEqual(layer.bias.get_shape().as_list(), [32]) + def testUnknownInputChannels(self): + images = random_ops.random_uniform((5, 7, 9, 4)) + images._shape = tensor_shape.as_shape((5, 7, 9, None)) + layer = conv_layers.Conv2D(32, [3, 3], activation=nn_ops.relu) + with self.assertRaisesRegexp(ValueError, + 'The channel dimension of the inputs ' + 'should be defined. Found `None`.'): + _ = layer.apply(images) + + images = random_ops.random_uniform((5, 4, 7, 9)) + images._shape = tensor_shape.as_shape((5, None, 7, 9)) + layer = conv_layers.Conv2D(32, [3, 3], data_format='channels_first') + with self.assertRaisesRegexp(ValueError, + 'The channel dimension of the inputs ' + 'should be defined. Found `None`.'): + _ = layer.apply(images) + def testConv2DPaddingSame(self): height, width = 7, 9 images = random_ops.random_uniform((5, height, width, 32), seed=1) @@ -135,6 +153,23 @@ class ConvTest(test.TestCase): self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 4, 32]) self.assertListEqual(layer.bias.get_shape().as_list(), [32]) + def testUnknownInputChannelsConv1D(self): + data = random_ops.random_uniform((5, 4, 7)) + data._shape = tensor_shape.as_shape((5, 4, None)) + layer = conv_layers.Conv1D(32, 3, activation=nn_ops.relu) + with self.assertRaisesRegexp(ValueError, + 'The channel dimension of the inputs ' + 'should be defined. Found `None`.'): + _ = layer.apply(data) + + data = random_ops.random_uniform((5, 7, 4)) + data._shape = tensor_shape.as_shape((5, None, 4)) + layer = conv_layers.Conv1D(32, 3, data_format='channels_first') + with self.assertRaisesRegexp(ValueError, + 'The channel dimension of the inputs ' + 'should be defined. Found `None`.'): + _ = layer.apply(data) + def testCreateConv3D(self): depth, height, width = 6, 7, 9 volumes = random_ops.random_uniform((5, depth, height, width, 4)) @@ -146,6 +181,15 @@ class ConvTest(test.TestCase): self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32]) self.assertListEqual(layer.bias.get_shape().as_list(), [32]) + def testUnknownInputChannelsConv3D(self): + volumes = random_ops.random_uniform((5, 6, 7, 9, 9)) + volumes._shape = tensor_shape.as_shape((5, 6, 7, 9, None)) + layer = conv_layers.Conv3D(32, [3, 3, 3], activation=nn_ops.relu) + with self.assertRaisesRegexp(ValueError, + 'The channel dimension of the inputs ' + 'should be defined. Found `None`.'): + _ = layer.apply(volumes) + def testConv2DKernelRegularizer(self): height, width = 7, 9 images = random_ops.random_uniform((5, height, width, 4)) diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index fd9c64939db..040a4513caa 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -254,6 +254,9 @@ class NumpyTensorBuffer : public TensorBuffer { return Tensor(dtype, shape, this); } + // Prevents input forwarding from overwriting this buffer. + bool OwnsMemory() const override { return false; } + private: PyArrayObject* array_; size_t len_; diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 8fdc6fba362..60057b9ab1e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2274,7 +2274,7 @@ def where(condition, x=None, y=None, name=None): If both non-None, `x` and `y` must have the same shape. The `condition` tensor must be a scalar if `x` and `y` are scalar. - If `x` and `y` are vectors or higher rank, then `condition` must be either a + If `x` and `y` are vectors of higher rank, then `condition` must be either a vector with size matching the first dimension of `x`, or must have the same shape as `x`. diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 4f9196daa43..dd28f4e64ef 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import warnings import numpy as np @@ -559,13 +560,16 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): self.assertAllClose(np_val, c_dense.eval()) def testWarnings(self): - # Smaller than the threshold: no warning. - c_sparse = ops.IndexedSlices( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) - with warnings.catch_warnings(record=True) as w: - math_ops.multiply(c_sparse, 1.0) - self.assertEqual(0, len(w)) + # TODO(gunan) Reenable after this issue is fixed: + # https://github.com/google/protobuf/issues/2812 + if sys.version_info < (3, 6): + # Smaller than the threshold: no warning. + c_sparse = ops.IndexedSlices( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) + with warnings.catch_warnings(record=True) as w: + math_ops.multiply(c_sparse, 1.0) + self.assertEqual(0, len(w)) # Greater than or equal to the threshold: warning. c_sparse = ops.IndexedSlices( diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 680da9b49b3..162b13ec212 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -663,7 +663,7 @@ def _dynamic_rnn_loop(cell, output_ta = tuple(_create_ta("output_%d" % i, _infer_state_dtype(dtype, state)) for i in range(len(flat_output_size))) - input_ta = tuple(_create_ta("input_%d" % i, flat_input[0].dtype) + input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) for i in range(len(flat_input))) input_ta = tuple(ta.unstack(input_) diff --git a/tensorflow/python/ops/transpose_benchmark.py b/tensorflow/python/ops/transpose_benchmark.py index b0d8b1d1c11..6bd3fe5e5a0 100644 --- a/tensorflow/python/ops/transpose_benchmark.py +++ b/tensorflow/python/ops/transpose_benchmark.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -def build_graph(device, input_shape, perm, datatype): +def build_graph(device, input_shape, perm, datatype, num_iters): """Build a graph containing a sequence of conv2d operations. Args: @@ -39,6 +39,7 @@ def build_graph(device, input_shape, perm, datatype): input_shape: Shape of the input tensor. perm: A list of ints with the same length as input tensor's dimension. datatype: numpy data type of the input tensor. + num_iters: number of iterations to run transpose. Returns: An array of tensors to run() @@ -46,11 +47,13 @@ def build_graph(device, input_shape, perm, datatype): with ops.device("/%s:0" % device): total_size = np.prod(input_shape) inp = np.arange(1, total_size + 1, dtype=datatype).reshape(input_shape) - t1 = constant_op.constant(inp, shape=input_shape) + t = constant_op.constant(inp, shape=input_shape) outputs = [] - for _ in range(10): - outputs.append(array_ops.transpose(t1, perm)) + outputs.append(array_ops.transpose(t, perm)) + for i in range(1, num_iters): + with ops.control_dependencies([outputs[i - 1]]): + outputs.append(array_ops.transpose(t, perm)) return control_flow_ops.group(*outputs) @@ -72,18 +75,17 @@ class TransposeBenchmark(test.Benchmark): """ graph = ops.Graph() with graph.as_default(): - outputs = build_graph(device, input_shape, perm, datatype) + outputs = build_graph(device, input_shape, perm, datatype, num_iters) with session_lib.Session(graph=graph) as session: variables.global_variables_initializer().run() - for _ in xrange(10): - session.run(outputs) + # warmup runs + session.run(outputs) start_time = time.time() - for _ in xrange(num_iters): - session.run(outputs) + session.run(outputs) duration = (time.time() - start_time) / num_iters - throughput = np.prod( - np.array(input_shape)) * 4 * 2 / duration / 1000000000 - print("%s %s inputshape:%s perm:%s %d %.4fsec, %.4fGB/s." % + throughput = np.prod(np.array( + input_shape)) * datatype().itemsize * 2 / duration / 1e9 + print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." % (device, str(datatype), str(input_shape).replace(" ", ""), str(perm).replace(" ", ""), num_iters, duration, throughput)) @@ -125,21 +127,21 @@ class TransposeBenchmark(test.Benchmark): huge_perms = [[0, 4, 1, 2, 3], [0, 3, 1, 2], [0, 2, 1], [4, 1, 2, 3, 0], [3, 1, 2, 0], [2, 1, 0]] + num_iters = 40 for datatype in datatypes: for ishape, perm in zip(small_shapes, small_perms): - self._run_graph("gpu", ishape, perm, 20, datatype) + self._run_graph("gpu", ishape, perm, num_iters, datatype) if datatype is not np.complex128: if datatype is not np.float16: for ishape, perm in zip(large_shapes, large_perms): - self._run_graph("gpu", ishape, perm, 20, datatype) + self._run_graph("gpu", ishape, perm, num_iters, datatype) if datatype is not np.complex128: if datatype is not np.float64: if datatype is not np.float16: for ishape, perm in zip(huge_shapes, huge_perms): - self._run_graph("gpu", ishape, perm, 20, datatype) - + self._run_graph("gpu", ishape, perm, num_iters, datatype) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index 2cef056f104..1e74b1512b8 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -104,10 +104,7 @@ def GetTempDir(): first_frame = inspect.stack()[-1][0] temp_dir = os.path.join( tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame))) - temp_dir = temp_dir.rstrip('.py') - if not os.path.isdir(temp_dir): - os.mkdir(temp_dir, 0o755) - temp_dir = tempfile.mkdtemp(dir=temp_dir) + temp_dir = tempfile.mkdtemp(prefix=temp_dir.rstrip('.py')) def delete_temp_dir(dirname=temp_dir): try: diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 44a6686cd61..d130588fa29 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -24,6 +24,7 @@ See the @{$python/summary} guide. @@histogram @@audio @@image +@@text @@merge @@merge_all @@get_summary_description @@ -56,8 +57,14 @@ from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops # pylint: disable=unused-import from tensorflow.python.ops.summary_ops import tensor_summary # pylint: enable=unused-import + from tensorflow.python.platform import tf_logging as _logging +# exports text +# pylint: disable=unused-import +from tensorflow.python.summary.text_summary import text_summary as text +# pylint: enable=unused-import + # exports FileWriter, FileWriterCache # pylint: disable=unused-import from tensorflow.python.summary.writer.writer import FileWriter diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py index 4b744fc3f7a..82dee45d267 100644 --- a/tensorflow/python/summary/text_summary.py +++ b/tensorflow/python/summary/text_summary.py @@ -59,7 +59,7 @@ def text_summary(name, tensor, collections=None): raise ValueError("Expected tensor %s to be scalar, has shape %s" % (tensor.name, tensor.shape)) - t_summary = tensor_summary(name, tensor, collections) + t_summary = tensor_summary(name, tensor, collections=collections) text_assets = plugin_asset.get_plugin_asset(TextSummaryPluginAsset) text_assets.register_tensor(t_summary.op.name) return t_summary diff --git a/tensorflow/python/summary/text_summary_test.py b/tensorflow/python/summary/text_summary_test.py index b4059778ed8..69739573c10 100644 --- a/tensorflow/python/summary/text_summary_test.py +++ b/tensorflow/python/summary/text_summary_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -46,6 +47,11 @@ class TextPluginTest(test_util.TensorFlowTestCase): summ = text_summary.text_summary("foo", array_ops.constant("one")) self.assertEqual(summ.op.type, "TensorSummary") + text_summary.text_summary("bar", array_ops.constant("2"), collections=[]) + summaries = framework_ops.get_collection( + framework_ops.GraphKeys.SUMMARIES) + self.assertEqual(len(summaries), 1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py index 49bb10b5128..13c07cfd7bf 100644 --- a/tensorflow/python/training/adadelta.py +++ b/tensorflow/python/training/adadelta.py @@ -37,6 +37,7 @@ class AdadeltaOptimizer(optimizer.Optimizer): Args: learning_rate: A `Tensor` or a floating point value. The learning rate. + To match the exact form in the original paper use 1.0. rho: A `Tensor` or a floating point value. The decay rate. epsilon: A `Tensor` or a floating point value. A constant epsilon used to better conditioning the grad update. diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 3d9a3fec3d1..f13b87dfed6 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -277,9 +277,8 @@ class CheckpointSaverListener(object): right thing, similar to what CheckpointSaverHook.end() does using self._timer.last_triggered_step(). - To use such listeners, pass them in the checkpoint_listeners argument to - graph_actions._monitored_train(). If using tf.Learn Estimators, create a - custom Estimator and override _get_checkpoint_listeners(). + To use such listeners, in your `model_fn` return a `CheckpointSaverHook` as + part of `training_chief_hooks`. """ def begin(self): diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index f5acd0deb9f..ae76a1ab580 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -256,7 +256,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name save_summaries_steps=100, save_summaries_secs=None, config=None, - stop_grace_period_secs=120): + stop_grace_period_secs=120, + log_step_count_steps=10000): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -292,6 +293,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name It's the `config` argument of constructor of `tf.Session`. stop_grace_period_secs: Number of seconds given to threads to stop after `close()` has been called. + log_step_count_steps: The frequency, in number of global steps, that the + global step/sec is logged. Returns: A `MonitoredSession` object. @@ -313,8 +316,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name config=config) if checkpoint_dir: - all_hooks.append( - basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir)) + all_hooks.append(basic_session_run_hooks.StepCounterHook( + output_dir=checkpoint_dir, every_n_steps=log_step_count_steps)) if (save_summaries_steps and save_summaries_steps > 0) or ( save_summaries_secs and save_summaries_secs > 0): diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index e22972206d4..41f8fb34869 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -224,7 +224,8 @@ class MonitoredTrainingSessionTest(test.TestCase): with monitored_session.MonitoredTrainingSession( is_chief=True, checkpoint_dir=logdir, - save_summaries_steps=100) as session: + save_summaries_steps=100, + log_step_count_steps=10) as session: for _ in range(101): session.run(new_gstep) summaries = util_test.latest_summaries(logdir) @@ -242,7 +243,8 @@ class MonitoredTrainingSessionTest(test.TestCase): is_chief=True, checkpoint_dir=logdir, save_summaries_steps=None, - save_summaries_secs=0.1) as session: + save_summaries_secs=0.1, + log_step_count_steps=10) as session: session.run(new_gstep) time.sleep(0.2) for _ in range(101): diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py index 28e28687f45..1da7f75531a 100644 --- a/tensorflow/python/training/proximal_adagrad_test.py +++ b/tensorflow/python/training/proximal_adagrad_test.py @@ -132,7 +132,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): for _ in range(10): update.run() v0_val, v1_val = sess.run([var0, var1]) - self.assertAllClose(np.array([0.662907, 0.767398]), v0_val) + self.assertAllClose(np.array([-6.663634, -9.190331]), v0_val) self.assertAllClose(np.array([2.959304, 1.029232]), v1_val) def testProximalAdagradWithL1_L2(self): @@ -159,8 +159,8 @@ class ProximalAdagradOptimizerTest(test.TestCase): update.run() v0_val, v1_val = sess.run([var0, var1]) - self.assertAllClose(np.array([0.043069, 0.080461]), v0_val) - self.assertAllClose(np.array([0.004069, 0.008578]), v1_val) + self.assertAllClose(np.array([-0.0495, -0.0995]), v0_val) + self.assertAllClose(np.array([-0.0045, -0.0095]), v1_val) def applyOptimizer(self, opt, steps=5, is_sparse=False): if is_sparse: diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py index 9c5ea670150..4e4812fe603 100644 --- a/tensorflow/python/training/proximal_gradient_descent_test.py +++ b/tensorflow/python/training/proximal_gradient_descent_test.py @@ -131,8 +131,8 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): update.run() v0_val, v1_val = sess.run([var0, var1]) - self.assertAllClose(np.array([0.037125, 0.074625]), v0_val) - self.assertAllClose(np.array([0.003375, 0.007125]), v1_val) + self.assertAllClose(np.array([-0.0495, -0.0995]), v0_val) + self.assertAllClose(np.array([-0.0045, -0.0095]), v1_val) def applyOptimizer(self, opt, steps=5, is_sparse=False): if is_sparse: diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index f24b4830997..9435bdfa1cc 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -706,12 +706,14 @@ class Supervisor(object): init_feed_dict=self._init_feed_dict, init_fn=self._init_fn) self._write_graph() if start_standard_services: + logging.info("Starting standard services.") self.start_standard_services(sess) else: sess = self._session_manager.wait_for_session(master, config=config, max_wait_secs=max_wait_secs) if start_standard_services: + logging.info("Starting queue runners.") self.start_queue_runners(sess) return sess @@ -992,6 +994,7 @@ class SVSummaryThread(coordinator.LooperThread): summary_strs = self._sess.run(self._sv.summary_op) global_step = None if self._sv.summary_writer: + logging.info("Recording summary at step %d.", global_step) self._sv.summary_writer.add_summary(summary_strs, global_step) @@ -1051,6 +1054,7 @@ class SVTimerCheckpointThread(coordinator.LooperThread): self._sess = sess def run_loop(self): + logging.info("Saving checkpoint to path %s", self._sv.save_path) self._sv.saver.save(self._sess, self._sv.save_path, global_step=self._sv.global_step) if self._sv.summary_writer and self._sv.global_step is not None: diff --git a/tensorflow/tensorboard/DEVELOPMENT.md b/tensorflow/tensorboard/DEVELOPMENT.md index 1d98df0a837..0a35dec42fb 100644 --- a/tensorflow/tensorboard/DEVELOPMENT.md +++ b/tensorflow/tensorboard/DEVELOPMENT.md @@ -50,3 +50,77 @@ produce a new tf-tensorboard.html with your changes. Now, you can use `bazel` to launch TensorBoard: `bazel run //tensorflow/tensorboard:tensorboard -- --logdir=/path/to/logs`. + +## Updating the vulcanized HTML file (for linux) + +The vulcanized HTML file `dist/tf-tensorboard.html.OPENSOURCE` is the version of +Tensorboard started up by users who install TensorFlow via pip. Today, updating +that file involves using gulp. Future efforts will streamline this process. + +First, `cd` into the `tensorflow/tensorboard` directory within a git repository +(a piper client will not work). Run `npm run prepare`. + +Next, we build some third party JS dependencies via webfiles targets. Run + + bazel build \ + tensorflow/tensorboard/components/tf_imports:d3 \ + tensorflow/tensorboard/components/tf_imports:lodash \ + tensorflow/tensorboard/components/tf_imports:graphlib \ + tensorflow/tensorboard/components/tf_imports:dagre \ + tensorflow/tensorboard/components/tf_imports:plottable + +Users internal to Google should use the internal build tool instead. Move the +output JS binaries into the tf_imports directory. + +Run `gulp vulcanize`. If compilation errors arise (such as those related to +TypeScript), fix them and re-run. This step should update the contents of +`dist/tf-tensorboard.html.OPENSOURCE`. + +Next, we perform some manual find-and-replaces on script `src` paths within +`dist/tf-tensorboard.html.OPENSOURCE`. Manually replace: + +* `` with `` +* `` with `` +* `` with `` +* `` with `` +* `` with `` + +Also, remove duplicate instances of script includes. Each of those scripts +should only be included once (the first time) within the vulcanized output. + +### Try out the vulcanized Tensorboard HTML output + +To test the vulcanized output, prepare a pip package within a virtualized +environment, and run `tensorboard` after activating the environment. + +To do that, we first create and activate a virtual environment called say +`tf_foo` (Pick your own name.). + + virtualenv --system-site-packages ~/tf_foo + source ~/tf_foo/bin/activate + +Make sure that you have installed `pip` and `virtualenv` beforehand. If not, run + + sudo easy_install pip + sudo pip install --upgrade virtualenv + +Next, we run this command from the `tensorflow directory`. + + tools/google/make_tree.sh --pip_dir=/tmp/pip_dir + +to create a pip package. If you are running within Google, also provide the +`--pending_cl` flag. That script will generate a wheel file (.whl) within +`/tmp/pip_dir`. Lets say that it is +`tensorflow-1.0.0rc2-cp27-none-linux_x86_64.whl`. + +Run + + pip install --upgrade /tmp/pip_dir/tensorflow-1.0.0rc2-cp27-none-linux_x86_64.whl + +to update the pip installation of TensorFlow within the virtual environment. +Verify that the `tensorboard` command defers to the tensorboard instance +installed within your virtual environment (`tf_foo`) by running +`which tensorboard`. To run tensorboard, start it up as usual within the virtual +environment: + + tensorboard --logdir=/tmp/my/logdir diff --git a/tensorflow/tensorboard/README.md b/tensorflow/tensorboard/README.md index 63c1619af8d..c9e997044c7 100644 --- a/tensorflow/tensorboard/README.md +++ b/tensorflow/tensorboard/README.md @@ -1,8 +1,7 @@ # TensorBoard TensorBoard is a suite of web applications for inspecting and understanding your -TensorFlow runs and graphs. TensorBoard currently supports five visualizations: -scalars, images, audio, histograms, and the graph. +TensorFlow runs and graphs. This README gives an overview of key concepts in TensorBoard, as well as how to interpret the visualizations TensorBoard provides. For an in-depth example of @@ -10,6 +9,10 @@ using TensorBoard, see the tutorial: [TensorBoard: Visualizing Learning](https://www.tensorflow.org/get_started/summaries_and_tensorboard). For in-depth information on the Graph Visualizer, see this tutorial: [TensorBoard: Graph Visualization](https://www.tensorflow.org/get_started/graph_viz). +You may also want to watch +[this video tutorial](https://www.youtube.com/watch?v=eBbEDRsCmv4) that walks +through setting up and using TensorBoard. + # Usage Before running TensorBoard, make sure you have generated summary data in a log @@ -21,8 +24,7 @@ directory by creating a summary writer: file_writer = tf.summary.FileWriter('/path/to/logs', sess.graph) ``` -For more details, see [this -tutorial](http://www.tensorflow.org/how_tos/summaries_and_tensorboard/index.html#serializing-the-data). +For more details, see [the TensorBoard tutorial](https://www.tensorflow.org/get_started/summaries_and_tensorboard). Once you have event files, run TensorBoard and provide the log directory. If you're using a precompiled TensorFlow package (e.g. you installed via pip), run: @@ -37,7 +39,8 @@ bazel build tensorflow/tensorboard:tensorboard ./bazel-bin/tensorflow/tensorboard/tensorboard --logdir=path/to/logs ``` -This should print that TensorBoard has started. Next, connect to http://localhost:6006. +This should print that TensorBoard has started. Next, connect to +http://localhost:6006. TensorBoard requires a `logdir` to read logs from. For info on configuring TensorBoard, run `tensorboard --help`. @@ -50,19 +53,25 @@ work, but there may be bugs or performance issues. ### Summary Ops: How TensorBoard gets data from TensorFlow The first step in using TensorBoard is acquiring data from your TensorFlow run. -For this, you need [summary -ops](https://www.tensorflow.org/versions/r1.1/api_docs/python/train.html#summary-operations). +For this, you need [summary ops](https://www.tensorflow.org/api_docs/python/tf/summary). Summary ops are ops, like -[`tf.matmul`](https://www.tensorflow.org/versions/r1.1/api_docs/python/math_ops.html#matmul) +[`tf.matmul`](https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/matmul) or -[`tf.nn.relu`](https://www.tensorflow.org/versions/r1.1/api_docs/python/nn.html#relu), +[`tf.nn.relu`](https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/relu), which means they take in tensors, produce tensors, and are evaluated from within a TensorFlow graph. However, summary ops have a twist: the Tensors they produce contain serialized protobufs, which are written to disk and sent to TensorBoard. To visualize the summary data in TensorBoard, you should evaluate the summary op, retrieve the result, and then write that result to disk using a summary.FileWriter. A full explanation, with examples, is in [the -tutorial](https://www.tensorflow.org/versions/r1.1/how_tos/summaries_and_tensorboard/index.html). +tutorial](https://www.tensorflow.org/get_started/summaries_and_tensorboard). + +The supported summary ops include: +* tf.summary.scalar +* tf.summary.image +* tf.summary.audio +* tf.summary.text +* tf.summary.histogram ### Tags: Giving names to data @@ -121,9 +130,9 @@ tensorboard --logdir=name1:/path/to/logs/1,name2:/path/to/logs/2 # The Visualizations -### Events Dashboard +### Scalar Dashboard -TensorBoard's Events Dashboard visualizes scalar statistics that vary over time; +TensorBoard's Scalar Dashboard visualizes scalar statistics that vary over time; for example, you might want to track the model's loss or learning rate. As described in *Key Concepts*, you can compare multiple runs, and the data is organized by tag. The line charts have the following interactions: @@ -141,12 +150,20 @@ the run-selector on the left. Additionally, you can create new folders to organize tags by writing regular expressions in the box in the top-left of the dashboard. +### Histogram Dashboard + +The HistogramDashboard displays how the statistical distribution of a Tensor +has varied over time. It visualizes data recorded via `tf.summary.histogram`. +Each chart shows temporal "slices" of data, where each slice is a histogram of +the tensor at a given step. It's organized with the oldest timestep in the back, +and the most recent timestep in front. By changing the Histogram Mode from +"offset" to "overlay", the perspective will rotate so that every histogram slice +is rendered as a line and overlaid with one another. + ### Distribution Dashboard -The Distribution Dashboard is for visualizing how the statistical distribution -of a Tensor has varied over time. It visualizes data recorded via a -tf.summary.histogram. Right now, its name is a bit of a misnomer, as it doesn't -show histograms; instead, it shows some high-level statistics on a distribution. +The Distribution Dashboard is another way of visualizing histogram data from +`tf.summary.histogram`. It shows some high-level statistics on a distribution. Each line on the chart represents a percentile in the distribution over the data: for example, the bottom line shows how the minimum value has changed over time, and the line in the middle shows how the median has changed. Reading from @@ -158,22 +175,20 @@ normal distribution: `[maximum, μ+1.5σ, μ+σ, μ+0.5σ, μ, μ-0.5σ, μ-σ, minimum]` so that the colored regions, read from inside to outside, have widths `[σ, 2σ, 3σ]` respectively. -This histogram visualization is a bit weird, and cannot meaningfully represent -multimodal distributions. We are currently working on a true-histogram -replacement. ### Image Dashboard -The Image Dashboard can display pngs that were saved via a tf.summary.image. The -dashboard is set up so that each row corresponds to a different tag, and each -column corresponds to a run. Since the image dashboard supports arbitrary pngs, -you can use this to embed custom visualizations (e.g. matplotlib scatterplots) -into TensorBoard. This dashboard always shows you the latest image for each tag. +The Image Dashboard can display pngs that were saved via a `tf.summary.image`. +The dashboard is set up so that each row corresponds to a different tag, and +each column corresponds to a run. Since the image dashboard supports arbitrary +pngs, you can use this to embed custom visualizations (e.g. matplotlib +scatterplots) into TensorBoard. This dashboard always shows you the latest image +for each tag. ### Audio Dashboard The Audio Dashboard can embed playable audio widgets for audio saved via a -tf.summary.audio. The dashboard is set up so that each row corresponds to a +`tf.summary.audio`. The dashboard is set up so that each row corresponds to a different tag, and each column corresponds to a run. This dashboard always embeds the latest audio for each tag. @@ -183,8 +198,21 @@ The Graph Explorer can visualize a TensorBoard graph, enabling inspection of the TensorFlow model. To get best use of the graph visualizer, you should use name scopes to hierarchically group the ops in your graph - otherwise, the graph may be difficult to decipher. For more information, including examples, see [the -graph visualizer -tutorial](https://www.tensorflow.org/versions/r1.1/how_tos/graph_viz/index.html#tensorboard-graph-visualization). +graph visualizer tutorial](https://www.tensorflow.org/get_started/graph_viz). + +### Embedding Projector + +The Embedding Projector allows you to visualize high-dimensional data; for +example, you may view your input data after it has been embedded in a high- +dimensional space by your model. The embedding projector reads data from your +model checkpoint file, and may be configured with additional metadata, like +a vocabulary file or sprite images. For more details, see [the embedding +projector tutorial](https://www.tensorflow.org/get_started/embedding_viz). + +### Text Dashboard + +The Text Dashboard displays text snippets saved via `tf.summary.text`. Markdown +features including hyperlinks, lists, and tables are all supported. # Frequently Asked Questions diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG index e373ee695f6..82cced27d7b 100644 --- a/tensorflow/tensorboard/TAG +++ b/tensorflow/tensorboard/TAG @@ -1 +1 @@ -50 +51 diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index b87ced2ec28..54d89a6bbb1 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -38,6 +38,9 @@ module TF.Backend { export type ScalarDatum = Datum & Scalar; export interface Scalar { scalar: number; } + export interface Text { text: string; } + export type TextDatum = Datum & Text; + export type HistogramDatum = Datum & Histogram; export interface Histogram { min: number; @@ -86,7 +89,7 @@ module TF.Backend { export var TYPES = [ 'scalar', 'histogram', 'compressedHistogram', 'graph', 'image', 'audio', - 'runMetadata' + 'runMetadata', 'text' ]; /** * The Backend class provides a convenient and typed interface to the backend. @@ -177,6 +180,27 @@ module TF.Backend { return this.runs().then((x) => _.mapValues(x, 'run_metadata')); } + + /** + * Returns a promise showing the Run-to-Tag mapping for text data. + */ + public textRuns(): Promise { + return this.requestManager.request(this.router.textRuns()); + } + + + /** + * Returns a promise containing TextDatums for given run and tag. + */ + public text(tag: string, run: string): Promise { + let url = this.router.text(tag, run); + // tslint:disable-next-line:no-any it's convenient and harmless here + return this.requestManager.request(url).then(map(function(x: any) { + x.wall_time = timeToDate(x.wall_time); + return x; + })); + } + /** * Return a promise of a graph string from the backend. */ diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts index fef3f9327ca..d14216dcfc9 100644 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ b/tensorflow/tensorboard/components/tf_backend/router.ts @@ -28,6 +28,8 @@ module TF.Backend { graph: (run: string, limit_attr_size?: number, large_attrs_key?: string) => string; runMetadata: RunTagUrlFn; + textRuns: () => string; + text: RunTagUrlFn; healthPills: () => string; }; @@ -100,6 +102,8 @@ module TF.Backend { audio: standardRoute('audio'), runMetadata: standardRoute('run_metadata', '.pbtxt'), healthPills: () => dataDir + '/plugin/debugger/health_pills', + textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? '.json' : ''), + text: standardRoute('plugin/text/text'), }; }; } diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html index c434bd47282..155259d3294 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html @@ -113,7 +113,7 @@ downloadLinkUrlFunction property to an appropriate value. display: flex; flex-direction: column; margin: 5px; - padding: 0 30px 35px 0; + padding: var(--card-padding, 0 30px 35px 0); -webkit-user-select: none; -moz-user-select: none; position: relative; @@ -166,6 +166,7 @@ downloadLinkUrlFunction property to an appropriate value. padding: 4px; border-radius: 100%; pointer-events: auto; + display: var(--show-expand-button, block); } .card-expanded .expand-button { diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html b/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html index 40e0b15413a..2da848bd99e 100644 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html +++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html @@ -82,6 +82,7 @@ contains vz-distribution-charts embedded inside tf-panes-helper's. color-scale="[[_colorScale]]" data-type="[[dataType]]" data-provider="[[dataProvider]]" + data-not-found="[[dataNotFound]]" run2tag="[[run2tag]]" selected-runs="[[_selectedRuns]]" repeat-for-runs diff --git a/tensorflow/tensorboard/components/tf_globals/globals.ts b/tensorflow/tensorboard/components/tf_globals/globals.ts index 33feb26d238..0dcdde8cd38 100644 --- a/tensorflow/tensorboard/components/tf_globals/globals.ts +++ b/tensorflow/tensorboard/components/tf_globals/globals.ts @@ -19,7 +19,7 @@ module TF.Globals { // The names of TensorBoard tabs. export var TABS = [ 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', - 'embeddings' + 'embeddings', 'text' ]; // If true, TensorBoard stores its hash in the URI state. diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html b/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html index c3105c3fd83..f8c9cc4537f 100644 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html +++ b/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html @@ -95,6 +95,7 @@ contains vz-histogram-timeseries embedded inside tf-panes-helper's. color-scale="[[_colorScale]]" data-type="[[dataType]]" data-provider="[[dataProvider]]" + data-not-found="[[dataNotFound]]" run2tag="[[run2tag]]" selected-runs="[[_selectedRuns]]" repeat-for-runs diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html index 0274a1f3391..19c272a4683 100644 --- a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html +++ b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html @@ -57,6 +57,7 @@ tf-image-dashboard displays a dashboard that loads images from a TensorFlow run. color-scale="[[_colorScale]]" data-type="[[dataType]]" data-provider="[[dataProvider]]" + data-not-found="[[dataNotFound]]" run2tag="[[run2tag]]" selected-runs="[[_selectedRuns]]" repeat-for-runs diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html index b18bc2e798c..641573366a6 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html @@ -121,6 +121,7 @@ contains vz-line-charts embedded inside tf-panes-helper's. color-scale="[[_colorScale]]" data-type="[[dataType]]" data-provider="[[dataProvider]]" + data-not-found="[[dataNotFound]]" run2tag="[[run2tag]]" selected-runs="[[_selectedRuns]]" show-download-links="[[_showDownloadLinks]]" diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html index 2df282c0531..b5b2e2d5a86 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html @@ -31,6 +31,7 @@ limitations under the License. + @@ -134,6 +135,13 @@ allows the user to toggle between various dashboards. route-prefix="/data/plugin/projector"> + + @@ -294,6 +302,9 @@ allows the user to toggle between various dashboards. _modeIsHistograms: function(mode) { return mode === "histograms"; }, + _modeIsText: function(mode) { + return mode === "text"; + }, selectedDashboard: function() { var dashboard = this.$$("#" + this.mode); if (dashboard == null) { diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD new file mode 100644 index 00000000000..9e4cd3614dc --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD @@ -0,0 +1,58 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +webfiles( + name = "tf_text_dashboard", + srcs = [ + "tf-text-dashboard.html", + "tf-text-loader.html", + ], + path = "/tf-text-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_color_scale", + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:lodash", + "@org_polymer", + "@org_polymer_paper_dialog", + "@org_polymer_paper_icon_button", + "@org_polymer_paper_material", + "@org_polymer_paper_slider", + "@org_polymer_paper_spinner", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) + +################################################################################ +# MARKED FOR DELETION + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [ + "tf-text-dashboard.html", + "tf-text-loader.html", + ], + destdir = "tf-text-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend:legacy", + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//third_party/javascript/polymer/v1/paper-material:lib", + ], +) + +# This is needed: components/BUILD seeks a legacy_ts rule in this package. +tensorboard_ts_library( + name = "legacy_ts", + srcs = [], +) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD new file mode 100644 index 00000000000..6cd6702e4b7 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD @@ -0,0 +1,25 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_text_dashboard/demo +webfiles( + name = "demo", + srcs = ["index.html"], + path = "/tf-text-dashboard/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_text_dashboard", + "//tensorflow/tensorboard/components/tf_text_dashboard/demo/data", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD new file mode 100644 index 00000000000..8adf661396c --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD @@ -0,0 +1,17 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") + +licenses(["notice"]) # Apache 2.0 + +webfiles( + name = "data", + srcs = glob(["*"]), + path = "/tf-text-dashboard/demo/data/plugin/text", +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/logdir b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/logdir new file mode 100644 index 00000000000..c7d82022cc0 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/logdir @@ -0,0 +1 @@ +{"logdir": "/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/runs.json b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/runs.json new file mode 100644 index 00000000000..aea7de5f917 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/runs.json @@ -0,0 +1 @@ +{"fry": ["message", "markdown"], "leela": ["message"]} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_markdown.json b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_markdown.json new file mode 100644 index 00000000000..94183ae13d1 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_markdown.json @@ -0,0 +1,32 @@ +[ + { + "wall_time": 1489715207.593146, + "step": 0, + "text": "

Italics1 Italics2 bold1 bold2

" + }, + { + "wall_time": 1489715207.593801, + "step": 1, + "text": "
    \n
  1. List item one.
  2. \n
  3. List item two.
  4. \n
  5. Sublist
  6. \n
  7. Sublist2
  8. \n
  9. List continues.
  10. \n
" + }, + { + "wall_time": 1489715207.594842, + "step": 2, + "text": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
AnExampleTable
ABC
123
" + }, + { + "wall_time": 1489715207.595761, + "step": 3, + "text": "

hello you

" + }, + { + "wall_time": 1489715207.595761, + "step": 4, + "text": "

TensorFlow

" + }, + { + "wall_time": 1489715207.595761, + "step": 530234352, + "text": "<script>alert('xss')</script>" + } +] diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_message.json b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_message.json new file mode 100644 index 00000000000..e8cc006c0d0 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_fry_tag_message.json @@ -0,0 +1,22 @@ +[ + { + "wall_time": 1489715207.593146, + "step": 0, + "text": "fry loves garnet" + }, + { + "wall_time": 1489715207.593801, + "step": 1, + "text": "fry loves amethyst" + }, + { + "wall_time": 1489715207.594842, + "step": 2, + "text": "fry loves pearl" + }, + { + "wall_time": 1489715207.595761, + "step": 3, + "text": "fry loves steven" + } +] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_leela_tag_message.json b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_leela_tag_message.json new file mode 100644 index 00000000000..5a6d2598937 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/text_run_leela_tag_message.json @@ -0,0 +1,22 @@ +[ + { + "step": 0, + "wall_time": 1489715207.607792, + "text": "leela loves garnet and feels strongly about various issues of the day including the two-cent titanium tax and whether nixon's head contributes to greenhouse gas emissions" + }, + { + "step": 1, + "wall_time": 1489715207.609011, + "text": "leela loves amethyst" + }, + { + "step": 2, + "wall_time": 1489715207.610028, + "text": "leela loves pearl" + }, + { + "step": 3, + "wall_time": 1489715207.611142, + "text": "leela loves someverylongwordwithoutanybreaksorspacessowecanseehowthatishandledbythefrontend" + } +] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_text_dashboard/demo/index.html new file mode 100644 index 00000000000..3ab6e857387 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/index.html @@ -0,0 +1,68 @@ + + + + + + + + + text Dashboard Demo + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html new file mode 100644 index 00000000000..d39c890a7ae --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html @@ -0,0 +1,106 @@ + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html new file mode 100644 index 00000000000..374e0478dd1 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html @@ -0,0 +1,143 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html index 34734aa4389..f7ef0593023 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html @@ -339,7 +339,7 @@ paper-dropdown-menu paper-item {

If you'd like to share your visualization with the world, follow these simple steps. - See this tutorial for more. + See this tutorial for more.

Step 1: Make data public

diff --git a/tensorflow/tensorboard/gulp_tasks/vulcanize.js b/tensorflow/tensorboard/gulp_tasks/vulcanize.js index b8cdd80af02..89700e1d4cc 100644 --- a/tensorflow/tensorboard/gulp_tasks/vulcanize.js +++ b/tensorflow/tensorboard/gulp_tasks/vulcanize.js @@ -53,6 +53,16 @@ var nonTBComponents = util.getComponents(function(name) { return prefix !== 'tf_' && prefix !== 'vz_'; }); +// These manual additions are necessary. The task should not inline these +// third-party javascript files. However, vulcanization still needs the HTML +// files found within those directories. Upon adding new third-party javascript, +// consider updating this list. +nonTBComponents.push('/tf-imports/d3.js'); +nonTBComponents.push('/tf-imports/dagre.js'); +nonTBComponents.push('/tf-imports/graphlib.js'); +nonTBComponents.push('/tf-imports/lodash.js'); +nonTBComponents.push('/tf-imports/plottable.js'); + module.exports = function(overwrite) { return function() { var suffix = overwrite ? '' : '.OPENSOURCE'; diff --git a/tensorflow/tensorboard/plugins/projector/BUILD b/tensorflow/tensorboard/plugins/projector/BUILD index 70e83fd844e..7c0ab64fb8d 100644 --- a/tensorflow/tensorboard/plugins/projector/BUILD +++ b/tensorflow/tensorboard/plugins/projector/BUILD @@ -1,5 +1,4 @@ -# Description: -# TensorBoard plugin for the Embedding Projector +# Embedding Projector plugin. package(default_visibility = ["//tensorflow:internal"]) @@ -9,7 +8,6 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") -## Embedding Projector Plugin ## py_library( name = "projector_plugin", srcs = ["projector_plugin.py"], @@ -19,9 +17,9 @@ py_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/contrib/tensorboard:projector", "//tensorflow/contrib/tensorboard:protos_all_py", "//tensorflow/python:errors", + "//tensorflow/python:image_ops", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -45,6 +43,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:init_ops", "//tensorflow/python:platform", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin.py b/tensorflow/tensorboard/plugins/projector/projector_plugin.py index 6ecd5972fc3..70450bf2a96 100644 --- a/tensorflow/tensorboard/plugins/projector/projector_plugin.py +++ b/tensorflow/tensorboard/plugins/projector/projector_plugin.py @@ -19,18 +19,23 @@ from __future__ import division from __future__ import print_function import imghdr +import math import os import numpy as np -from werkzeug import wrappers +from six import BytesIO +from werkzeug import wrappers from google.protobuf import json_format from google.protobuf import text_format -from tensorflow.contrib.tensorboard.plugins.projector import PROJECTOR_FILENAME -from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig +from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import image_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.pywrap_tensorflow import NewCheckpointReader +from tensorflow.python.summary import plugin_asset from tensorflow.python.training.saver import checkpoint_exists from tensorflow.python.training.saver import latest_checkpoint from tensorflow.tensorboard.backend.http_util import Respond @@ -39,6 +44,8 @@ from tensorflow.tensorboard.plugins.base_plugin import TBPlugin # The prefix of routes provided by this plugin. PLUGIN_PREFIX_ROUTE = 'projector' +PROJECTOR_FILENAME = 'projector_config.pbtxt' + # HTTP routes. CONFIG_ROUTE = '/info' TENSOR_ROUTE = '/tensor' @@ -56,6 +63,193 @@ _IMGHDR_TO_MIMETYPE = { _DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream' +class EmbeddingMetadata(object): + """Metadata container for an embedding. + + The metadata holds different columns with values used for visualization + (color by, label by) in the "Embeddings" tab in TensorBoard. + """ + + def __init__(self, num_points): + """Constructs a metadata for an embedding of the specified size. + + Args: + num_points: Number of points in the embedding. + """ + self.num_points = num_points + self.column_names = [] + self.name_to_values = {} + + def add_column(self, column_name, column_values): + """Adds a named column of metadata values. + + Args: + column_name: Name of the column. + column_values: 1D array/list/iterable holding the column values. Must be + of length `num_points`. The i-th value corresponds to the i-th point. + + Raises: + ValueError: If `column_values` is not 1D array, or of length `num_points`, + or the `name` is already used. + """ + # Sanity checks. + if isinstance(column_values, list) and isinstance(column_values[0], list): + raise ValueError('"column_values" must be a flat list, but we detected ' + 'that its first entry is a list') + + if isinstance(column_values, np.ndarray) and column_values.ndim != 1: + raise ValueError('"column_values" should be of rank 1, ' + 'but is of rank %d' % column_values.ndim) + if len(column_values) != self.num_points: + raise ValueError('"column_values" should be of length %d, but is of ' + 'length %d' % (self.num_points, len(column_values))) + if column_name in self.name_to_values: + raise ValueError('The column name "%s" is already used' % column_name) + + self.column_names.append(column_name) + self.name_to_values[column_name] = column_values + + +class ProjectorPluginAsset(plugin_asset.PluginAsset): + """Provides a registry for assets needed by the Projector plugin.""" + plugin_name = 'org_tensorflow_tensorboard_projector' + + def __init__(self): + self._config = projector_config_pb2.ProjectorConfig() + self._assets = {} + self._used_names = set() + + def add_metadata_for_embedding_variable(self, + var_name, + metadata=None, + thumbnails=None, + thumbnail_dim=None): + """Adds metadata for an embedding variable stored in a checkpoint file. + + Args: + var_name: Name of the embedding variable. + metadata: Optional. A `Metadata` container mapping column header names to + the values of that column. + thumbnails: Optional. A 4D `ndarray` or a list of 3D `ndarray`s. Each + 3D array represents the pixels [height, width, channels] of a single + thumbnail. The i-th image corresponds to the i-th row (data point) of + the embedding variable. + thumbnail_dim: Required if `thumbnails` is provided. A tuple + (height, width) of a single thumbnail in the sprite. + + Raises: + ValueError: If the name of the variable was previously used in this + object, or both `metadata` and `thumbnails` are None. + """ + + if metadata is None and thumbnails is None: + raise ValueError('At least one of (`metadata`, `thumbnails`) must be ' + 'provided') + self._convert_embedding_to_assets(var_name, None, metadata, thumbnails, + thumbnail_dim) + + def add_embedding(self, + name, + values, + metadata=None, + thumbnails=None, + thumbnail_dim=None): + """Adds an embedding asset to be visualized by the Embedding Projector. + + Args: + name: Name of the embedding. + values: 2D `ndarray` of shape [numPoints, dimensionality] + containing the embedding values. The i-th row corresponds to the i-th + data point. + metadata: Optional. A `Metadata` container mapping column header names to + the values of that column. + thumbnails: Optional. A 4D `ndarray` or a list of 3D `ndarray`s. Each + 3D array represents the pixels [height, width, channels] of a single + thumbnail. The i-th image corresponds to the i-th row (data point) of + the `values` matrix. + thumbnail_dim: Required if `thumbnails` is provided. A tuple + (height, width) of a single thumbnail in the sprite. + + Raises: + ValueError: If the name of the embedding was previously used in this + object, or `values` is not a 2D array. + """ + + # Sanity checks. + if values.ndim != 2: + raise ValueError('`values` must be a 2D array, but is ' + '%d-D' % values.ndim) + self._convert_embedding_to_assets(name, values, metadata, thumbnails, + thumbnail_dim) + + def _convert_embedding_to_assets(self, + name, + values=None, + metadata=None, + thumbnails=None, + thumbnail_dim=None): + """Converts the data associated with embeddings into serializable assets.""" + + if name in self._used_names: + raise ValueError('The name "%s" was previously used' % name) + if thumbnails is not None and not thumbnail_dim: + raise ValueError('`thumbnail_dim` is required when `thumbnails` is ' + 'provided') + if thumbnail_dim is not None: + if not isinstance(thumbnail_dim, (list, tuple, np.ndarray)): + raise ValueError('`thumbnail_dim` must be either a list, tuple or ' + '`ndarray`') + if len(thumbnail_dim) != 2: + raise ValueError('`thumbnail_dim` must be of length 2, ' + 'but is of length %d' % len(thumbnail_dim)) + if metadata: + if values is not None and len(values) != metadata.num_points: + raise ValueError('First dimension of `values` "%d" must match ' + '`metadata.num_points` "%d"' % (len(values), + metadata.num_points)) + if not metadata.column_names: + raise ValueError('The provided metadata has no columns. Did you forget ' + 'to add a column?') + + self._used_names.add(name) + embedding_info = self._config.embeddings.add() + embedding_info.tensor_name = name + + if values is not None: + bytes_io = BytesIO() + np.savetxt(bytes_io, values, fmt='%.6g', delimiter='\t') + fname = '{}_values.tsv'.format(name) + embedding_info.tensor_path = fname + embedding_info.tensor_shape.extend(values.shape) + self._assets[fname] = bytes_io.getvalue() + + if metadata: + metadata_tsv_lines = [] + should_have_header = len(metadata.column_names) > 1 + if should_have_header: + metadata_tsv_lines.append('\t'.join(metadata.column_names)) + + for i in range(metadata.num_points): + row = [ + metadata.name_to_values[col_name][i] + for col_name in metadata.column_names + ] + metadata_tsv_lines.append('\t'.join(map(str, row))) + fname = '{}_metadata.tsv'.format(name) + embedding_info.metadata_path = fname + self._assets[fname] = '\n'.join(metadata_tsv_lines) + '\n' + + if thumbnails is not None: + fname = '{}_sprite.png'.format(name) + embedding_info.sprite.image_path = fname + embedding_info.sprite.single_image_dim.extend(thumbnail_dim) + self._assets[fname] = _make_sprite_image(thumbnails, thumbnail_dim) + + def assets(self): + self._assets[PROJECTOR_FILENAME] = text_format.MessageToString(self._config) + return self._assets + + def _read_tensor_file(fpath): with file_io.FileIO(fpath, 'r') as f: tensor = [] @@ -69,7 +263,7 @@ def _latest_checkpoints_changed(configs, run_path_pairs): """Returns true if the latest checkpoint has changed in any of the runs.""" for run_name, logdir in run_path_pairs: if run_name not in configs: - config = ProjectorConfig() + config = projector_config_pb2.ProjectorConfig() config_fpath = os.path.join(logdir, PROJECTOR_FILENAME) if file_io.file_exists(config_fpath): file_content = file_io.read_file_to_string(config_fpath) @@ -202,7 +396,7 @@ class ProjectorPlugin(TBPlugin): configs = {} config_fpaths = {} for run_name, logdir in run_path_pairs: - config = ProjectorConfig() + config = projector_config_pb2.ProjectorConfig() config_fpath = os.path.join(logdir, PROJECTOR_FILENAME) if file_io.file_exists(config_fpath): file_content = file_io.read_file_to_string(config_fpath) @@ -461,3 +655,39 @@ def _find_latest_checkpoint(dir_path): return ckpt_path except errors.NotFoundError: return None + + +def _make_sprite_image(thumbnails, thumbnail_dim): + """Constructs a sprite image from thumbnails and returns the png bytes.""" + if len(thumbnails) < 1: + raise ValueError('The length of "thumbnails" must be >= 1') + + if isinstance(thumbnails, np.ndarray) and thumbnails.ndim != 4: + raise ValueError('"thumbnails" should be of rank 4, ' + 'but is of rank %d' % thumbnails.ndim) + if isinstance(thumbnails, list): + if not isinstance(thumbnails[0], np.ndarray) or thumbnails[0].ndim != 3: + raise ValueError('Each element of "thumbnails" must be a 3D `ndarray`') + thumbnails = np.array(thumbnails) + + with ops.Graph().as_default(): + s = session.Session() + resized_images = image_ops.resize_images(thumbnails, thumbnail_dim).eval( + session=s) + images_per_row = int(math.ceil(math.sqrt(len(thumbnails)))) + thumb_height = thumbnail_dim[0] + thumb_width = thumbnail_dim[1] + master_height = images_per_row * thumb_height + master_width = images_per_row * thumb_width + num_channels = thumbnails.shape[3] + master = np.zeros([master_height, master_width, num_channels]) + for idx, image in enumerate(resized_images): + left_idx = idx % images_per_row + top_idx = int(math.floor(idx / images_per_row)) + left_start = left_idx * thumb_width + left_end = left_start + thumb_width + top_start = top_idx * thumb_height + top_end = top_start + thumb_height + master[top_start:top_end, left_start:left_end, :] = image + + return image_ops.encode_png(master).eval(session=s) diff --git a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py index 790b3468a36..069e8be84ec 100644 --- a/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py +++ b/tensorflow/tensorboard/plugins/projector/projector_plugin_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,17 +27,19 @@ import numpy as np from werkzeug import test as werkzeug_test from werkzeug import wrappers - from google.protobuf import text_format -from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig +from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops +from tensorflow.python.ops import image_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.summary import plugin_asset +from tensorflow.python.summary.writer import writer from tensorflow.python.training import saver as saver_lib from tensorflow.tensorboard.backend import application from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -100,9 +103,9 @@ class ProjectorAppTest(test.TestCase): multiplexer = event_multiplexer.EventMultiplexer( size_guidance=application.DEFAULT_SIZE_GUIDANCE, purge_orphaned_data=True) - projector = projector_plugin.ProjectorPlugin() - projector.get_plugin_apps(multiplexer, self.log_dir) - plugins = {'projector': projector} + plugin = projector_plugin.ProjectorPlugin() + plugin.get_plugin_apps(multiplexer, self.log_dir) + plugins = {'projector': plugin} wsgi_app = application.TensorBoardWSGIApp( self.log_dir, plugins, multiplexer, reload_interval=0) self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) @@ -119,7 +122,7 @@ class ProjectorAppTest(test.TestCase): def _GenerateProjectorTestData(self): config_path = os.path.join(self.log_dir, 'projector_config.pbtxt') - config = ProjectorConfig() + config = projector_config_pb2.ProjectorConfig() embedding = config.embeddings.add() # Add an embedding by its canonical tensor name. embedding.tensor_name = 'var1:0' @@ -140,5 +143,276 @@ class ProjectorAppTest(test.TestCase): saver.save(sess, checkpoint_path) +class MetadataColumnsTest(test.TestCase): + + def testLengthDoesNotMatch(self): + metadata = projector_plugin.EmbeddingMetadata(10) + + with self.assertRaises(ValueError): + metadata.add_column('Labels', [''] * 11) + + def testValuesNot1D(self): + metadata = projector_plugin.EmbeddingMetadata(3) + values = np.array([[1, 2, 3]]) + + with self.assertRaises(ValueError): + metadata.add_column('Labels', values) + + def testMultipleColumnsRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('Sizes', [1, 2, 3]) + metadata.add_column('Labels', ['a', 'b', 'c']) + self.assertEqual(metadata.column_names, ['Sizes', 'Labels']) + self.assertEqual(metadata.name_to_values['Labels'], ['a', 'b', 'c']) + self.assertEqual(metadata.name_to_values['Sizes'], [1, 2, 3]) + + def testValuesAreListofLists(self): + metadata = projector_plugin.EmbeddingMetadata(3) + values = [[1, 2, 3], [4, 5, 6]] + with self.assertRaises(ValueError): + metadata.add_column('Labels', values) + + def testStringListRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('Labels', ['a', 'B', 'c']) + self.assertEqual(metadata.name_to_values['Labels'], ['a', 'B', 'c']) + self.assertEqual(metadata.column_names, ['Labels']) + + def testNumericListRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('Labels', [1, 2, 3]) + self.assertEqual(metadata.name_to_values['Labels'], [1, 2, 3]) + + def testNumericNdArrayRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('Labels', np.array([1, 2, 3])) + self.assertEqual(metadata.name_to_values['Labels'].tolist(), [1, 2, 3]) + + def testStringNdArrayRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(2) + metadata.add_column('Labels', np.array(['a', 'b'])) + self.assertEqual(metadata.name_to_values['Labels'].tolist(), ['a', 'b']) + + def testDuplicateColumnName(self): + metadata = projector_plugin.EmbeddingMetadata(2) + metadata.add_column('Labels', np.array(['a', 'b'])) + with self.assertRaises(ValueError): + metadata.add_column('Labels', np.array(['a', 'b'])) + + +class ProjectorPluginAssetTest(test.TestCase): + + def testNoAssets(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + self.assertEqual(manager.assets(), {'projector_config.pbtxt': ''}) + + def testAddEmbeddingNoMetadata(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + manager.add_embedding('test', np.array([[1, 2, 3.1]])) + + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + embedding.tensor_name = 'test' + embedding.tensor_shape.extend([1, 3]) + embedding.tensor_path = 'test_values.tsv' + expected_config_pbtxt = text_format.MessageToString(config) + + self.assertEqual(manager.assets(), { + 'projector_config.pbtxt': expected_config_pbtxt, + 'test_values.tsv': b'1\t2\t3.1\n' + }) + + def testAddEmbeddingIncorrectRank(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + with self.assertRaises(ValueError): + manager.add_embedding('test', np.array([1, 2, 3.1])) + + def testAddEmbeddingWithTwoMetadataColumns(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('labels', ['a', 'b', 'друг јазик']) + metadata.add_column('sizes', [10, 20, 30]) + manager.add_embedding('test', np.array([[1], [2], [3]]), metadata) + + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + embedding.tensor_name = 'test' + embedding.tensor_shape.extend([3, 1]) + embedding.tensor_path = 'test_values.tsv' + embedding.metadata_path = 'test_metadata.tsv' + expected_config_pbtxt = text_format.MessageToString(config) + + self.assertEqual(manager.assets(), { + 'projector_config.pbtxt': expected_config_pbtxt, + 'test_values.tsv': b'1\n2\n3\n', + 'test_metadata.tsv': 'labels\tsizes\na\t10\nb\t20\nдруг јазик\t30\n' + }) + + def testAddEmbeddingWithOneMetadataColumn(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('labels', ['a', 'b', 'c']) + manager.add_embedding('test', np.array([[1], [2], [3]]), metadata) + + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + embedding.tensor_name = 'test' + embedding.tensor_shape.extend([3, 1]) + embedding.tensor_path = 'test_values.tsv' + embedding.metadata_path = 'test_metadata.tsv' + expected_config_pbtxt = text_format.MessageToString(config) + + self.assertEqual(manager.assets(), { + 'projector_config.pbtxt': expected_config_pbtxt, + 'test_values.tsv': b'1\n2\n3\n', + 'test_metadata.tsv': 'a\nb\nc\n' + }) + + def testAddEmbeddingWithThumbnails(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + image1 = np.array([[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]) + image2 = np.array([[[10, 20, 30], [40, 50, 60]], + [[70, 80, 90], [100, 110, 120]]]) + manager.add_embedding( + 'test', + np.array([[1], [2], [3]]), + thumbnails=[image1, image2], + thumbnail_dim=[2, 2]) + + assets = manager.assets() + + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + embedding.tensor_name = 'test' + embedding.tensor_shape.extend([3, 1]) + embedding.tensor_path = 'test_values.tsv' + embedding.sprite.image_path = 'test_sprite.png' + embedding.sprite.single_image_dim.extend([2, 2]) + expected_config_pbtxt = text_format.MessageToString(config) + + self.assertEqual(assets['projector_config.pbtxt'], expected_config_pbtxt) + self.assertEqual(assets['test_values.tsv'], b'1\n2\n3\n') + + png_bytes = assets['test_sprite.png'] + with ops.Graph().as_default(): + s = session.Session() + image_array = image_ops.decode_png(png_bytes).eval(session=s).tolist() + expected_master_image = [ + [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], + [[7, 8, 9], [10, 11, 12], [70, 80, 90], [100, 110, 120]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]] + ] + self.assertEqual(image_array, expected_master_image) + + def testAddEmbeddingWithSpriteImageButNoThumbnailDim(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) + with self.assertRaises(ValueError): + manager.add_embedding( + 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails) + + def testAddEmbeddingThumbnailDimNotAList(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) + with self.assertRaises(ValueError): + manager.add_embedding( + 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails, + thumbnail_dim=4) + + def testAddEmbeddingThumbnailDimNotOfLength2(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + thumbnails = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) + with self.assertRaises(ValueError): + manager.add_embedding( + 'test', np.array([[1], [2], [3]]), thumbnails=thumbnails, + thumbnail_dim=[4]) + + def testAddEmbeddingWithMetadataOfIncorrectLength(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('labels', ['a', 'b', 'c']) + # values has length 2, while metadata has length 3. + values = np.array([[1], [2]]) + + with self.assertRaises(ValueError): + manager.add_embedding('test', values, metadata) + + def testAddMetadataForVariableButNoColumns(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + metadata = projector_plugin.EmbeddingMetadata(3) + with self.assertRaises(ValueError): + manager.add_metadata_for_embedding_variable('test', metadata) + + def testAddMetadataForVariable(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column('Labels', ['a', 'b', 'c']) + manager.add_metadata_for_embedding_variable('test', metadata) + + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + embedding.tensor_name = 'test' + embedding.metadata_path = 'test_metadata.tsv' + expected_config_pbtxt = text_format.MessageToString(config) + + self.assertEqual(manager.assets(), { + 'projector_config.pbtxt': expected_config_pbtxt, + 'test_metadata.tsv': 'a\nb\nc\n' + }) + + def testAddMetadataForVariableAtLeastOneParamIsRequired(self): + manager = plugin_asset.get_plugin_asset( + projector_plugin.ProjectorPluginAsset) + with self.assertRaises(ValueError): + manager.add_metadata_for_embedding_variable('test') + + def testNoAssetsProperSerializationOnDisk(self): + logdir = self.get_temp_dir() + plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, + projector_plugin.ProjectorPluginAsset.plugin_name) + + with ops.Graph().as_default() as g: + plugin_asset.get_plugin_asset(projector_plugin.ProjectorPluginAsset) + fw = writer.FileWriter(logdir) + fw.add_graph(g) + + with gfile.Open(os.path.join(plugin_dir, 'projector_config.pbtxt')) as f: + content = f.read() + self.assertEqual(content, '') + + def testNoReferenceToPluginNoSerializationOnDisk(self): + logdir = self.get_temp_dir() + plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, + projector_plugin.ProjectorPluginAsset.plugin_name) + + with ops.Graph().as_default() as g: + fw = writer.FileWriter(logdir) + fw.add_graph(g) + + self.assertFalse( + gfile.Exists(plugin_dir), + 'The projector plugin directory should not exist.') + if __name__ == '__main__': test.main() diff --git a/tensorflow/tensorboard/plugins/text/BUILD b/tensorflow/tensorboard/plugins/text/BUILD index 1f027511170..3e1455ea5a4 100644 --- a/tensorflow/tensorboard/plugins/text/BUILD +++ b/tensorflow/tensorboard/plugins/text/BUILD @@ -21,7 +21,9 @@ py_library( "//tensorflow/python:summary", "//tensorflow/tensorboard/backend:http_util", "//tensorflow/tensorboard/plugins:base_plugin", + "@org_mozilla_bleach", "@org_pocoo_werkzeug//:werkzeug", + "@org_pythonhosted_markdown", ], ) diff --git a/tensorflow/tensorboard/plugins/text/text_plugin.py b/tensorflow/tensorboard/plugins/text/text_plugin.py index fe786e15edf..a87949877bf 100644 --- a/tensorflow/tensorboard/plugins/text/text_plugin.py +++ b/tensorflow/tensorboard/plugins/text/text_plugin.py @@ -19,6 +19,12 @@ from __future__ import division from __future__ import print_function import json + +import bleach +# pylint: disable=g-bad-import-order +# Google-only: import markdown_freewisdom +import markdown +# pylint: enable=g-bad-import-order from werkzeug import wrappers from tensorflow.python.summary import text_summary @@ -29,15 +35,71 @@ from tensorflow.tensorboard.plugins import base_plugin PLUGIN_PREFIX_ROUTE = 'text' # HTTP routes -RUNS_ROUTE = '/index' +RUNS_ROUTE = '/runs' TEXT_ROUTE = '/text' +ALLOWED_TAGS = [ + 'ul', + 'ol', + 'li', + 'p', + 'pre', + 'code', + 'blockquote', + 'h1', + 'h2', + 'h3', + 'h4', + 'h5', + 'h6', + 'hr', + 'br', + 'strong', + 'em', + 'a', + 'img', + 'table', + 'thead', + 'tbody', + 'td', + 'tr', + 'th', +] + +ALLOWED_ATTRIBUTES = {'a': ['href', 'title'], 'img': ['src', 'title', 'alt']} + + +def markdown_and_sanitize(markdown_string): + """Takes a markdown string and converts it into sanitized html. + + It uses the table extension; while that's not a part of standard + markdown, it is sure to be useful for TensorBoard users. + + The sanitizer uses the allowed_tags and attributes specified above. Mostly, + we ensure that our standard use cases like tables and links are supported. + + Args: + markdown_string: Markdown string to sanitize + + Returns: + a string containing sanitized html for input markdown + """ + # Convert to utf-8 because we get a bytearray in python3 + if not isinstance(markdown_string, str): + markdown_string = markdown_string.decode('utf-8') + string_html = markdown.markdown( + markdown_string, extensions=['markdown.extensions.tables']) + string_sanitized = bleach.clean( + string_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRIBUTES) + return string_sanitized + def process_string_tensor_event(event): + """Convert a TensorEvent into a JSON-compatible response.""" return { 'wall_time': event.wall_time, 'step': event.step, - 'text': event.tensor_proto.string_val[0], + 'text': markdown_and_sanitize(event.tensor_proto.string_val[0]), } diff --git a/tensorflow/tensorboard/plugins/text/text_plugin_test.py b/tensorflow/tensorboard/plugins/text/text_plugin_test.py index aa8ce03727f..f1d06aa4f17 100644 --- a/tensorflow/tensorboard/plugins/text/text_plugin_test.py +++ b/tensorflow/tensorboard/plugins/text/text_plugin_test.py @@ -19,8 +19,7 @@ from __future__ import division from __future__ import print_function import os - -import six +import textwrap from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -46,6 +45,10 @@ class TextPluginTest(test.TestCase): self.plugin = text_plugin.TextPlugin() self.apps = self.plugin.get_plugin_apps(multiplexer, None) + def assertConverted(self, actual, expected): + expected_html = text_plugin.markdown_and_sanitize(expected) + self.assertEqual(actual, expected_html) + def generate_testdata(self): ops.reset_default_graph() sess = session.Session() @@ -60,7 +63,7 @@ class TextPluginTest(test.TestCase): step = 0 for gem in GEMS: - message = run_name + " loves " + gem + message = run_name + " *loves* " + gem feed_dict = {placeholder: message} summ = sess.run(summary_tensor, feed_dict=feed_dict) writer.add_summary(summ, global_step=step) @@ -81,9 +84,85 @@ class TextPluginTest(test.TestCase): self.assertEqual(len(leela), 4) for i in range(4): self.assertEqual(fry[i]["step"], i) - self.assertEqual(fry[i]["text"], six.b("fry loves " + GEMS[i])) + self.assertConverted(fry[i]["text"], "fry *loves* " + GEMS[i]) self.assertEqual(leela[i]["step"], i) - self.assertEqual(leela[i]["text"], six.b("leela loves " + GEMS[i])) + self.assertConverted(leela[i]["text"], "leela *loves* " + GEMS[i]) + + def assertTextConverted(self, actual, expected): + self.assertEqual(text_plugin.markdown_and_sanitize(actual), expected) + + def testMarkdownConversion(self): + emphasis = "*Italics1* _Italics2_ **bold1** __bold2__" + emphasis_converted = ("

Italics1 Italics2 " + "bold1 bold2

") + + self.assertEqual( + text_plugin.markdown_and_sanitize(emphasis), emphasis_converted) + + md_list = textwrap.dedent("""\ + 1. List item one. + 2. List item two. + * Sublist + * Sublist2 + 1. List continues. + """) + md_list_converted = textwrap.dedent("""\ +
    +
  1. List item one.
  2. +
  3. List item two.
  4. +
  5. Sublist
  6. +
  7. Sublist2
  8. +
  9. List continues.
  10. +
""") + self.assertEqual( + text_plugin.markdown_and_sanitize(md_list), md_list_converted) + + link = "[TensorFlow](http://tensorflow.org)" + link_converted = '

TensorFlow

' + self.assertEqual(text_plugin.markdown_and_sanitize(link), link_converted) + + table = textwrap.dedent("""\ + An | Example | Table + --- | --- | --- + A | B | C + 1 | 2 | 3 + """) + + table_converted = textwrap.dedent("""\ + + + + + + + + + + + + + + + + + + + + +
AnExampleTable
ABC
123
""") + + self.assertEqual(text_plugin.markdown_and_sanitize(table), table_converted) + + def testSanitization(self): + dangerous = "" + sanitized = "<script>alert('xss')</script>" + self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) + + dangerous = textwrap.dedent("""\ + hello *you*""") + sanitized = "

hello you

" + self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) if __name__ == "__main__": diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 8011f8de243..5ebd69bd3da 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -35,6 +35,14 @@ pip3 install --upgrade six==1.10.0 pip2 install --upgrade werkzeug==0.11.10 pip3 install --upgrade werkzeug==0.11.10 +# Install bleach. html5lib will be picked up as a dependency. +pip2 install --upgrade bleach==1.5.0 +pip3 install --upgrade bleach==1.5.0 + +# Install markdown. +pip2 install --upgrade markdown==2.6.8 +pip3 install --upgrade markdown==2.6.8 + # Install protobuf. pip2 install --upgrade protobuf==3.2.0 pip3 install --upgrade protobuf==3.2.0 diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index b173f869ddb..640ddb6df83 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -49,11 +49,11 @@ def _is_free_function(py_object, full_name, index): return True -def write_docs(output_dir, parser_config, duplicate_of, index, yaml_toc): +def write_docs(output_dir, parser_config, yaml_toc): """Write previously extracted docs to disk. - Write a docs page for each symbol in `index` to a tree of docs at - `output_dir`. + Write a docs page for each symbol included in the indices of parser_config to + a tree of docs at `output_dir`. Symbols with multiple aliases will have only one page written about them, which is referenced for all aliases. @@ -61,11 +61,8 @@ def write_docs(output_dir, parser_config, duplicate_of, index, yaml_toc): Args: output_dir: Directory to write documentation markdown files to. Will be created if it doesn't exist. - parser_config: A `parser.ParserConfig` object. - duplicate_of: A `dict` mapping fully qualified names to "master" names. - Used to determine which docs pages to write. - index: A `dict` mapping fully qualified names to the corresponding Python - objects. Used to produce docs for child objects. + parser_config: A `parser.ParserConfig` object, containing all the necessary + indices. yaml_toc: Set to `True` to generate a "_toc.yaml" file. """ # Make output_dir. @@ -84,15 +81,14 @@ def write_docs(output_dir, parser_config, duplicate_of, index, yaml_toc): symbol_to_file = {} # Parse and write Markdown pages, resolving cross-links (@{symbol}). - for full_name, py_object in six.iteritems(index): + for full_name, py_object in six.iteritems(parser_config.index): - if full_name in duplicate_of: + if full_name in parser_config.duplicate_of: continue # Methods and some routines are documented only as part of their class. - if not (inspect.ismodule(py_object) or - inspect.isclass(py_object) or - _is_free_function(py_object, full_name, index)): + if not (inspect.ismodule(py_object) or inspect.isclass(py_object) or + _is_free_function(py_object, full_name, parser_config.index)): continue sitepath = os.path.join('api_docs/python', @@ -113,7 +109,7 @@ def write_docs(output_dir, parser_config, duplicate_of, index, yaml_toc): subname = str(full_name) while True: subname = subname[:subname.rindex('.')] - if inspect.ismodule(index[subname]): + if inspect.ismodule(parser_config.index[subname]): module_children.setdefault(subname, []).append(full_name) break @@ -162,8 +158,8 @@ def write_docs(output_dir, parser_config, duplicate_of, index, yaml_toc): # Write a global index containing all full names with links. with open(os.path.join(output_dir, 'index.md'), 'w') as f: - f.write(parser.generate_global_index( - 'TensorFlow', index, parser_config.reference_resolver)) + f.write(parser.generate_global_index('TensorFlow', parser_config.index, + parser_config.reference_resolver)) def add_dict_to_dict(add_from, add_to): @@ -463,17 +459,17 @@ class DocGenerator(object): return [name for (name, _) in self._py_modules] def make_reference_resolver(self, visitor, doc_index): - return parser.ReferenceResolver( - duplicate_of=visitor.duplicate_of, - doc_index=doc_index, index=visitor.index, - py_module_names=self.py_module_names()) + return parser.ReferenceResolver.from_visitor( + visitor, doc_index, py_module_names=self.py_module_names()) def make_parser_config(self, visitor, reference_resolver, guide_index, base_dir): return parser.ParserConfig( reference_resolver=reference_resolver, duplicates=visitor.duplicates, + duplicate_of=visitor.duplicate_of, tree=visitor.tree, + index=visitor.index, reverse_index=visitor.reverse_index, guide_index=guide_index, base_dir=base_dir) @@ -486,14 +482,15 @@ class DocGenerator(object): doc_index = build_doc_index(flags.src_dir) visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + guide_index = _build_guide_index( os.path.join(flags.src_dir, 'api_guides/python')) + parser_config = self.make_parser_config(visitor, reference_resolver, guide_index, flags.base_dir) output_dir = os.path.join(flags.output_dir, 'api_docs/python') - write_docs(output_dir, parser_config, visitor.duplicate_of, visitor.index, - yaml_toc=self.yaml_toc) + write_docs(output_dir, parser_config, yaml_toc=self.yaml_toc) _other_docs(flags.src_dir, flags.output_dir, reference_resolver) if parser.all_errors: diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index a28afa290a0..c8d4c3fe7e2 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import os import sys -import tempfile import tensorflow as tf @@ -46,6 +45,13 @@ class TestClass(object): pass +class DummyVisitor(object): + + def __init__(self, index, duplicate_of): + self.index = index + self.duplicate_of = duplicate_of + + class GenerateTest(googletest.TestCase): def test_extraction(self): @@ -85,43 +91,59 @@ class GenerateTest(googletest.TestCase): 'tf.TestModule.TestClass.ChildClass.GrandChildClass': [] } - duplicate_of = { - 'tf.TestModule.test_function': 'tf.test_function' - } + duplicate_of = {'tf.test_function': 'tf.TestModule.test_function'} duplicates = { - 'tf.test_function': ['tf.test_function', 'tf.TestModule.test_function'] + 'tf.TestModule.test_function': [ + 'tf.test_function', 'tf.TestModule.test_function' + ] } - output_dir = tempfile.mkdtemp() base_dir = os.path.dirname(__file__) - reference_resolver = parser.ReferenceResolver( - duplicate_of=duplicate_of, - doc_index={}, index=index, py_module_names=['tf']) + visitor = DummyVisitor(index, duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + parser_config = parser.ParserConfig( - reference_resolver=reference_resolver, duplicates=duplicates, tree=tree, - reverse_index={}, guide_index={}, base_dir=base_dir) - generate_lib.write_docs(output_dir, parser_config, duplicate_of, index, - yaml_toc=True) + reference_resolver=reference_resolver, + duplicates=duplicates, + duplicate_of=duplicate_of, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir=base_dir) + + output_dir = googletest.GetTempDir() + + generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) # Make sure that the right files are written to disk. self.assertTrue(os.path.exists(os.path.join(output_dir, 'index.md'))) self.assertTrue(os.path.exists(os.path.join(output_dir, 'tf.md'))) self.assertTrue(os.path.exists(os.path.join(output_dir, '_toc.yaml'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'tf/TestModule.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'tf/test_function.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'tf/TestModule/TestClass.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'tf/TestModule/TestClass/ChildClass.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'tf/TestModule/TestClass/ChildClass/GrandChildClass.md'))) + self.assertTrue( + os.path.exists(os.path.join(output_dir, 'tf/TestModule.md'))) + self.assertFalse( + os.path.exists(os.path.join(output_dir, 'tf/test_function.md'))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, 'tf/TestModule/TestClass.md'))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, + 'tf/TestModule/TestClass/ChildClass.md'))) + self.assertTrue( + os.path.exists( + os.path.join( + output_dir, + 'tf/TestModule/TestClass/ChildClass/GrandChildClass.md'))) # Make sure that duplicates are not written - self.assertFalse(os.path.exists(os.path.join( - output_dir, 'tf/TestModule/test_function.md'))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, 'tf/TestModule/test_function.md'))) if __name__ == '__main__': diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 84b944afc9d..3da58d2b3c7 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -22,6 +22,7 @@ import ast import collections import functools import inspect +import json import os import re @@ -96,12 +97,71 @@ class ReferenceResolver(object): py_module_names: A list of string names of Python modules. """ - def __init__(self, duplicate_of, doc_index, index, py_module_names): + def __init__(self, duplicate_of, doc_index, is_class, is_module, + py_module_names): self._duplicate_of = duplicate_of self._doc_index = doc_index - self._index = index + self._is_class = is_class + self._is_module = is_module + self._all_names = set(is_class.keys()) self._py_module_names = py_module_names + @classmethod + def from_visitor(cls, visitor, doc_index, **kwargs): + """A factory function for building a ReferenceResolver from a visitor. + + Args: + visitor: an instance of `DocGeneratorVisitor` + doc_index: a dictionary mapping document names to references objects with + "title" and "url" fields + **kwargs: all remaining args are passed to the constructor + Returns: + an instance of `ReferenceResolver` () + """ + is_class = { + name: inspect.isclass(visitor.index[name]) + for name, obj in visitor.index.items() + } + + is_module = { + name: inspect.ismodule(visitor.index[name]) + for name, obj in visitor.index.items() + } + + return cls( + duplicate_of=visitor.duplicate_of, + doc_index=doc_index, + is_class=is_class, + is_module=is_module, + **kwargs) + + @classmethod + def from_json_file(cls, filepath, doc_index): + with open(filepath) as f: + json_dict = json.load(f) + + return cls(doc_index=doc_index, **json_dict) + + def to_json_file(self, filepath): + """Converts the RefenceResolver to json and writes it to the specified file. + + Args: + filepath: The file path to write the json to. + """ + json_dict = {} + for key, value in self.__dict__.items(): + # Drop these two fields. `_doc_index` is not serializable. `_all_names` is + # generated by the constructor. + if key in ('_doc_index', '_all_names'): + continue + + # Strip off any leading underscores on field names as these are not + # recognized by the constructor. + json_dict[key.lstrip('_')] = value + + with open(filepath, 'w') as f: + json.dump(json_dict, f) + def replace_references(self, string, relative_path_to_root): """Replace "@{symbol}" references with links to symbol's documentation page. @@ -161,10 +221,6 @@ class ReferenceResolver(object): """Return the master name for a Python symbol name.""" return self._duplicate_of.get(full_name, full_name) - def py_name_to_object(self, full_name): - """Return the Python object for a Python symbol name.""" - return self._index[full_name] - def reference_to_url(self, ref_full_name, relative_path_to_root): """Resolve a "@{python symbol}" reference to a relative path. @@ -187,12 +243,12 @@ class ReferenceResolver(object): to the documentation page of `ref_full_name`. Raises: - RuntimeError: If `ref_full_name` is not in `self._index`. + RuntimeError: If `ref_full_name` is not documented. """ master_name = self._duplicate_of.get(ref_full_name, ref_full_name) # Check whether this link exists - if master_name not in self._index: + if master_name not in self._all_names: # TODO(josh11b): Make error reporting more uniform. print('ERROR: Cannot make link to %s (original: %s): Not in index.' % (master_name, ref_full_name)) @@ -200,13 +256,12 @@ class ReferenceResolver(object): # If this is a member of a class, link to the class page with an anchor. ref_path = None - py_object = self._index[master_name] - if not (inspect.isclass(py_object) or inspect.ismodule(py_object)): + if not (self._is_class[master_name] or self._is_module[master_name]): idents = master_name.split('.') if len(idents) > 1: class_name = '.'.join(idents[:-1]) - assert class_name in self._index - if inspect.isclass(self._index[class_name]): + assert class_name in self._all_names + if self._is_class[class_name]: ref_path = documentation_path(class_name) + '#%s' % idents[-1] if not ref_path: @@ -841,24 +896,20 @@ class _ClassPageInfo(object): other_member_info = _OtherMemberInfo(short_name, full_name, obj, doc) self._other_members.append(other_member_info) - def collect_docs_for_class(self, py_class, - reference_resolver, tree, reverse_index): + def collect_docs_for_class(self, py_class, parser_config): """Collect information necessary specifically for a class's doc page. Mainly, this is details about information about the class's members. Args: - py_class: The class object to collect docs for. - reference_resolver: An instance of ReferenceResolver. - tree: A map from full names to the names of all documentable child - objects. - reverse_index: A map from object ids in the index to full names. + py_class: the class object being documented + parser_config: An instance of ParserConfig. """ doc_path = documentation_path(self.full_name) relative_path = os.path.relpath( path='.', start=os.path.dirname(doc_path) or '.') - for short_name in tree[self.full_name]: + for short_name in parser_config.tree[self.full_name]: # Remove builtin members that we never want to document. if short_name in ['__class__', '__base__', '__weakref__', '__doc__', '__module__', '__dict__', '__abstractmethods__', @@ -866,7 +917,7 @@ class _ClassPageInfo(object): continue child_name = '.'.join([self.full_name, short_name]) - child = reference_resolver.py_name_to_object(child_name) + child = parser_config.py_name_to_object(child_name) # Don't document anything that is defined in object or by protobuf. defining_class = _get_defining_class(py_class, short_name) @@ -879,7 +930,8 @@ class _ClassPageInfo(object): continue # TODO(markdaoust): Add a note in child docs showing the defining class. - child_doc = _parse_md_docstring(child, relative_path, reference_resolver) + child_doc = _parse_md_docstring(child, relative_path, + parser_config.reference_resolver) if isinstance(child, property): self._add_property(short_name, child_name, child, child_doc) @@ -887,7 +939,7 @@ class _ClassPageInfo(object): elif inspect.isclass(child): if defining_class is None: continue - url = reference_resolver.reference_to_url( + url = parser_config.reference_resolver.reference_to_url( child_name, relative_path) self._add_class(short_name, child_name, child, child_doc, url) @@ -912,7 +964,8 @@ class _ClassPageInfo(object): continue try: - child_signature = _generate_signature(child, reverse_index) + child_signature = _generate_signature(child, + parser_config.reverse_index) except TypeError: # If this is a (dynamically created) slot wrapper, inspect will # raise typeerror when trying to get to the code. Ignore such @@ -1025,34 +1078,32 @@ class _ModulePageInfo(object): self._other_members.append( _OtherMemberInfo(short_name, full_name, obj, doc)) - def collect_docs_for_module(self, reference_resolver, tree): + def collect_docs_for_module(self, parser_config): """Collect information necessary specifically for a module's doc page. Mainly this is information about the members of the module. Args: - reference_resolver: An instance of ReferenceResolver. - tree: A map from full names to the names of all documentable child - objects. + parser_config: An instance of ParserConfig. """ relative_path = os.path.relpath( path='.', start=os.path.dirname(documentation_path(self.full_name)) or '.') - member_names = tree.get(self.full_name, []) + member_names = parser_config.tree.get(self.full_name, []) for name in member_names: - if name in ['__builtins__', '__doc__', '__file__', '__name__', '__path__', - '__package__']: + if name in ['__builtins__', '__doc__', '__file__', + '__name__', '__path__', '__package__']: continue member_full_name = self.full_name + '.' + name if self.full_name else name - member = reference_resolver.py_name_to_object(member_full_name) + member = parser_config.py_name_to_object(member_full_name) member_doc = _parse_md_docstring(member, relative_path, - reference_resolver) + parser_config.reference_resolver) - url = reference_resolver.reference_to_url( + url = parser_config.reference_resolver.reference_to_url( member_full_name, relative_path) if inspect.ismodule(member): @@ -1069,9 +1120,10 @@ class _ModulePageInfo(object): class ParserConfig(object): + """Stores all indexes required to parse the docs.""" - def __init__(self, reference_resolver, duplicates, tree, reverse_index, - guide_index, base_dir): + def __init__(self, reference_resolver, duplicates, duplicate_of, tree, index, + reverse_index, guide_index, base_dir): """Object with the common config for docs_for_object() calls. Args: @@ -1079,24 +1131,35 @@ class ParserConfig(object): duplicates: A `dict` mapping fully qualified names to a set of all aliases of this name. This is used to automatically generate a list of all aliases for each name. + duplicate_of: A map from duplicate names to preferred names of API + symbols. tree: A `dict` mapping a fully qualified name to the names of all its members. Used to populate the members section of a class or module page. - reverse_index: A `dict` mapping objects in the index to full names. + index: A `dict` mapping full names to objects. + reverse_index: A `dict` mapping object ids to full names. + guide_index: A `dict` mapping symbol name strings to objects with a `make_md_link()` method. + base_dir: A base path that is stripped from file locations written to the docs. """ self.reference_resolver = reference_resolver self.duplicates = duplicates + self.duplicate_of = duplicate_of self.tree = tree self.reverse_index = reverse_index + self.index = index self.guide_index = guide_index self.base_dir = base_dir self.defined_in_prefix = 'tensorflow/' self.code_url_prefix = ( 'https://www.tensorflow.org/code/tensorflow/') # pylint: disable=line-too-long + def py_name_to_object(self, full_name): + """Return the Python object for a Python symbol name.""" + return self.index[full_name] + def docs_for_object(full_name, py_object, parser_config): """Return a PageInfo object describing a given object from the TF API. @@ -1141,15 +1204,11 @@ def docs_for_object(full_name, py_object, parser_config): elif inspect.isclass(py_object): page_info = _ClassPageInfo(master_name) - page_info.collect_docs_for_class(py_object, - parser_config.reference_resolver, - parser_config.tree, - parser_config.reverse_index) + page_info.collect_docs_for_class(py_object, parser_config) elif inspect.ismodule(py_object): page_info = _ModulePageInfo(master_name) - page_info.collect_docs_for_module(parser_config.reference_resolver, - parser_config.tree) + page_info.collect_docs_for_module(parser_config) else: raise RuntimeError('Cannot make docs for object %s: %r' % (full_name, diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 11c5e92d84c..2bab6b3de4b 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -56,6 +56,13 @@ class TestClass(object): CLASS_MEMBER = 'a class member' +class DummyVisitor(object): + + def __init__(self, index, duplicate_of): + self.index = index + self.duplicate_of = duplicate_of + + class ParserTest(googletest.TestCase): def test_documentation_path(self): @@ -75,9 +82,12 @@ class ParserTest(googletest.TestCase): 'tf.reference.foo': HasOneMember.foo, 'tf.third': HasOneMember, 'tf.fourth': HasOneMember} - reference_resolver = parser.ReferenceResolver( - duplicate_of=duplicate_of, doc_index={}, index=index, - py_module_names=['tf']) + + visitor = DummyVisitor(index, duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + result = reference_resolver.replace_references(string, '../..') self.assertEqual( 'A [`tf.reference`](../../tf/reference.md), another ' @@ -98,8 +108,11 @@ class ParserTest(googletest.TestCase): doc2.title = 'Two words' doc2.url = 'somewhere/else' doc_index = {'doc1': doc1, 'do/c2': doc2} - reference_resolver = parser.ReferenceResolver( - duplicate_of={}, doc_index=doc_index, index={}, py_module_names=['tf']) + + visitor = DummyVisitor(index={}, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index=doc_index, py_module_names=['tf']) result = reference_resolver.replace_references(string, 'python') self.assertEqual( '[Title1](../URL1) [Title1](../URL1#abc) [link](../URL1) ' @@ -115,15 +128,24 @@ class ParserTest(googletest.TestCase): 'TestClass.ChildClass': TestClass.ChildClass, 'TestClass.CLASS_MEMBER': TestClass.CLASS_MEMBER } - reference_resolver = parser.ReferenceResolver( - duplicate_of={}, doc_index={}, index=index, py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) tree = { 'TestClass': ['a_method', 'a_property', 'ChildClass', 'CLASS_MEMBER'] } parser_config = parser.ParserConfig( - reference_resolver=reference_resolver, duplicates={}, tree=tree, - reverse_index={}, guide_index={}, base_dir='/') + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') page_info = parser.docs_for_object( full_name='TestClass', py_object=TestClass, parser_config=parser_config) @@ -158,16 +180,25 @@ class ParserTest(googletest.TestCase): test_function_with_args_kwargs, 'TestModule.TestClass': TestClass, } - reference_resolver = parser.ReferenceResolver( - duplicate_of={}, doc_index={}, index=index, py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) tree = { 'TestModule': ['TestClass', 'test_function', 'test_function_with_args_kwargs'] } parser_config = parser.ParserConfig( - reference_resolver=reference_resolver, duplicates={}, tree=tree, - reverse_index={}, guide_index={}, base_dir='/') + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') page_info = parser.docs_for_object( full_name='TestModule', py_object=module, parser_config=parser_config) @@ -189,15 +220,24 @@ class ParserTest(googletest.TestCase): index = { 'test_function': test_function } - reference_resolver = parser.ReferenceResolver( - duplicate_of={}, doc_index={}, index=index, py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) tree = { '': ['test_function'] } parser_config = parser.ParserConfig( - reference_resolver=reference_resolver, duplicates={}, tree=tree, - reverse_index={}, guide_index={}, base_dir='/') + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') page_info = parser.docs_for_object( full_name='test_function', @@ -219,15 +259,24 @@ class ParserTest(googletest.TestCase): index = { 'test_function_with_args_kwargs': test_function_with_args_kwargs } - reference_resolver = parser.ReferenceResolver( - duplicate_of={}, doc_index={}, index=index, py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) tree = { '': ['test_function_with_args_kwargs'] } parser_config = parser.ParserConfig( - reference_resolver=reference_resolver, duplicates={}, tree=tree, - reverse_index={}, guide_index={}, base_dir='/') + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') page_info = parser.docs_for_object( full_name='test_function_with_args_kwargs', @@ -287,9 +336,11 @@ class ParserTest(googletest.TestCase): 'tf.third': HasOneMember, 'tf.fourth': HasOneMember } - reference_resolver = parser.ReferenceResolver( - duplicate_of=duplicate_of, doc_index={}, index=index, - py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of=duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) doc_info = parser._parse_md_docstring(test_function_with_fancy_docstring, '../..', reference_resolver) @@ -319,9 +370,11 @@ class ParserTest(googletest.TestCase): duplicate_of = { 'TestModule.test_function': 'test_function' } - reference_resolver = parser.ReferenceResolver( - duplicate_of=duplicate_of, doc_index={}, index=index, - py_module_names=['tf']) + + visitor = DummyVisitor(index=index, duplicate_of=duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) docs = parser.generate_global_index('TestLibrary', index=index, reference_resolver=reference_resolver) @@ -389,6 +442,37 @@ class ParserTest(googletest.TestCase): # pylint: enable=protected-access + def testSaveReferenceResolver(self): + you_cant_serialize_this = object() + + duplicate_of = {'AClass': ['AClass2']} + doc_index = {'doc': you_cant_serialize_this} + is_class = { + 'tf': False, + 'tf.AClass': True, + 'tf.AClass2': True, + 'tf.function': False + } + is_module = { + 'tf': True, + 'tf.AClass': False, + 'tf.AClass2': False, + 'tf.function': False + } + py_module_names = ['tf', 'tfdbg'] + + resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class, + is_module, py_module_names) + + outdir = googletest.GetTempDir() + + filepath = os.path.join(outdir, 'resolver.json') + + resolver.to_json_file(filepath) + resolver2 = parser.ReferenceResolver.from_json_file(filepath, doc_index) + + # There are no __slots__, so all fields are visible in __dict__. + self.assertEqual(resolver.__dict__, resolver2.__dict__) RELU_DOC = """Computes rectified linear: `max(features, 0)` diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d9c67862e74..c3a839030d8 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -109,7 +109,10 @@ filegroup( "@png_archive//:LICENSE", "@protobuf//:LICENSE", "@six_archive//:LICENSE", + "@org_html5lib//:LICENSE", + "@org_mozilla_bleach//:LICENSE", "@org_pocoo_werkzeug//:LICENSE", + "@org_pythonhosted_markdown//:LICENSE.md", "@zlib_archive//:zlib.h", ] + if_not_windows([ "@nccl_archive//:LICENSE.txt", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 6baaf2f99b4..6f1c539608b 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -36,6 +36,9 @@ REQUIRED_PACKAGES = [ 'six >= 1.10.0', 'protobuf >= 3.2.0', 'werkzeug >= 0.11.10', + 'html5lib == 1.0b8', + 'markdown == 2.2.0', + 'bleach == 1.5.0', ] project_name = 'tensorflow' diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index dda0b0b7e32..dac04440d03 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -217,6 +217,39 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): build_file = str(Label("//third_party:six.BUILD")), ) + native.new_http_archive( + name = "org_pythonhosted_markdown", + urls = [ + "http://bazel-mirror.storage.googleapis.com/pypi.python.org/packages/1d/25/3f6d2cb31ec42ca5bd3bfbea99b63892b735d76e26f20dd2dcc34ffe4f0d/Markdown-2.6.8.tar.gz", + "https://pypi.python.org/packages/1d/25/3f6d2cb31ec42ca5bd3bfbea99b63892b735d76e26f20dd2dcc34ffe4f0d/Markdown-2.6.8.tar.gz", + ], + strip_prefix = "Markdown-2.6.8", + sha256 = "0ac8a81e658167da95d063a9279c9c1b2699f37c7c4153256a458b3a43860e33", + build_file = str(Label("//third_party:markdown.BUILD")), + ) + + native.new_http_archive( + name = "org_html5lib", + urls = [ + "http://bazel-mirror.storage.googleapis.com/github.com/html5lib/html5lib-python/archive/1.0b8.tar.gz", + "https://github.com/html5lib/html5lib-python/archive/1.0b8.tar.gz", + ], + sha256 = "adb36c879264e8880b92589c4c4fe0814cd9d157b73328b14d728f48a6bab0a4", + strip_prefix = "html5lib-python-1.0b8", + build_file = str(Label("//third_party:html5lib.BUILD")), + ) + + native.new_http_archive( + name = "org_mozilla_bleach", + urls = [ + "http://bazel-mirror.storage.googleapis.com/github.com/mozilla/bleach/archive/v1.5.tar.gz", + "https://github.com/mozilla/bleach/archive/v1.5.tar.gz", + ], + strip_prefix = "bleach-1.5", + sha256 = "0d68713d02ba4148c417ab1637dd819333d96929a34401d0233947bec0881ad8", + build_file = str(Label("//third_party:bleach.BUILD")), + ) + native.new_http_archive( name = "org_pocoo_werkzeug", urls = [ diff --git a/third_party/bleach.BUILD b/third_party/bleach.BUILD new file mode 100644 index 00000000000..1bf75b84a76 --- /dev/null +++ b/third_party/bleach.BUILD @@ -0,0 +1,20 @@ +# Description: +# Build file for Bleach. +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "org_mozilla_bleach", + srcs = [ + "bleach/__init__.py", + "bleach/callbacks.py", + "bleach/encoding.py", + "bleach/sanitizer.py", + "bleach/version.py", + ], + srcs_version = "PY2AND3", + deps = ["@org_html5lib"], +) diff --git a/third_party/html5lib.BUILD b/third_party/html5lib.BUILD new file mode 100644 index 00000000000..63aac14f155 --- /dev/null +++ b/third_party/html5lib.BUILD @@ -0,0 +1,17 @@ +# Description: +# Import of html5lib library. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # BSD-like notice-style license, see LICENSE file + +exports_files(["LICENSE"]) + +py_library( + name = "org_html5lib", + srcs = glob(["html5lib/**/*.py"]), + srcs_version = "PY2AND3", + deps = [ + "@six_archive//:six", + ], +) diff --git a/third_party/markdown.BUILD b/third_party/markdown.BUILD new file mode 100644 index 00000000000..fa3e85d5304 --- /dev/null +++ b/third_party/markdown.BUILD @@ -0,0 +1,15 @@ +# Description: +# Markdown processor + +package(default_visibility = ["//visibility:public"]) + +# This software says they use a BSD license. +licenses(["notice"]) + +exports_files(["LICENSE.md"]) + +py_library( + name = "org_pythonhosted_markdown", + srcs = glob(["markdown/**/*.py"]), + srcs_version = "PY2AND3", +)