diff --git a/.bazelrc b/.bazelrc index 8fd166c10a5..f2eae57d4a1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -19,10 +19,10 @@ # Compiler options: # cuda_clang: Use clang when building CUDA code. # c++17: Build with C++17 options -# C++1z: Build with C++17 options +# c++1z: Build with C++17 options # avx_linux: Build with avx instruction set on linux. # avx2_linux: Build with avx2 instruction set on linux. -# arch_native_linux: Build with instruction sets available to the host machine on linux +# native_arch_linux: Build with instruction sets available to the host machine on linux # avx_win: Build with avx instruction set on windows # avx2_win: Build with avx2 instruction set on windows # diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b4dc0e73975..ccc03cc046d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -88,6 +88,9 @@ TensorFlow coding style. submitting PRs to fix one typo, one warning,etc. We recommend fixing the same issue at the file level at least (e.g.: fix all typos in a file, fix all compiler warning in a file, etc.) +* Tests should follow the + [testing best practices](https://www.tensorflow.org/community/contribute/tests) + guide. #### License diff --git a/README.md b/README.md index 5b4dd28c446..05ddb90fabc 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ Build Type | Status * [TensorFlow Examples](https://github.com/tensorflow/examples) * [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice) * [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment) +* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2) * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) * [TensorFlow Blog](https://blog.tensorflow.org) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 632cb682348..d22eafada16 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -13,16 +13,16 @@ # limitations under the License. # ============================================================================== """ -Top-level module of TensorFlow. By convention, we refer to this module as -`tf` instead of `tensorflow`, following the common practice of importing +Top-level module of TensorFlow. By convention, we refer to this module as +`tf` instead of `tensorflow`, following the common practice of importing TensorFlow via the command `import tensorflow as tf`. -The primary function of this module is to import all of the public TensorFlow -interfaces into a single place. The interfaces themselves are located in +The primary function of this module is to import all of the public TensorFlow +interfaces into a single place. The interfaces themselves are located in sub-modules, as described below. -Note that the file `__init__.py` in the TensorFlow source code tree is actually -only a placeholder to enable test cases to run. The TensorFlow build replaces +Note that the file `__init__.py` in the TensorFlow source code tree is actually +only a placeholder to enable test cases to run. The TensorFlow build replaces this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py) """ @@ -41,6 +41,11 @@ import sys as _sys from tensorflow.python.tools import module_util as _module_util from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader +# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import. +_os.environ['TF2_BEHAVIOR'] = '1' +from tensorflow.python import tf2 as _tf2 +_tf2.enable() + # API IMPORTS PLACEHOLDER # WRAPPER_PLACEHOLDER diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 4e7ba3943ae..1e9498777c4 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -683,7 +683,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type, tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({})); std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); - return TFE_TensorHandle::CreateLocalHandle(tensor, status); + + status->status = tensorflow::Status::OK(); + return new TFE_TensorHandle{ + std::make_unique( + tensorflow::TensorHandle::CreateLocalHandle(tensor))}; } namespace { @@ -703,7 +707,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, } while (0); // New server created for new server_def. Unused if updating server_def. - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server = dynamic_cast(context->GetServer()); if (grpc_server == nullptr) { @@ -822,8 +827,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, for (int i = 0; i < num_inputs; ++i) { node_def.add_input("dummy_input"); } - tensorflow::down_cast( - tfe_op->operation.get()) + OperationFromInterface(tfe_op->operation) ->Attrs() .FillAttrValueMap(node_def.mutable_attr()); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c25cb264ce7..01375f115e9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -28,6 +28,9 @@ tf_cuda_library( "c_api_debug.cc", "c_api_experimental.h", "c_api_internal.h", + "c_api_unified_experimental.h", + "context_interface.cc", + "context_interface.h", "operation_interface.cc", "operation_interface.h", "tensor_handle_interface.h", @@ -62,6 +65,7 @@ tf_cuda_library( "//tensorflow/core/platform:errors", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/types:variant", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -95,6 +99,8 @@ filegroup( srcs = [ "c_api_experimental.h", "c_api_internal.h", + "c_api_unified_experimental.h", + "context_interface.h", "dlpack.h", "operation_interface.h", "tensor_handle_interface.h", @@ -109,6 +115,8 @@ tf_cuda_library( name = "c_api_internal", srcs = [ "c_api_experimental.h", + "c_api_unified_experimental.h", + "context_interface.h", "operation_interface.h", "tensor_handle_interface.h", ], @@ -206,7 +214,6 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "//tensorflow/core/platform:casts", "@com_google_absl//absl/strings", ], ) @@ -215,8 +222,12 @@ tf_cuda_library( name = "c_api_experimental", srcs = [ "c_api_experimental.cc", + "c_api_unified_experimental.cc", + ], + hdrs = [ + "c_api_experimental.h", + "c_api_unified_experimental.h", ], - hdrs = ["c_api_experimental.h"], copts = tf_copts() + tfe_xla_copts(), visibility = ["//visibility:public"], deps = select({ @@ -242,6 +253,7 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:variant", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -293,6 +305,30 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "c_api_unified_experimental_test", + size = "small", + srcs = [ + "c_api_unified_experimental_test.cc", + ], + args = ["--heap_check=local"], + extra_copts = tfe_xla_copts(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/cc/profiler", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "custom_device_test", size = "small", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 94a0a76ada1..a38bdc6cbb0 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts( server_def.default_session_config()); std::vector filtered_device_mask; - ctx->context->FilterDevicesForRemoteWorkers( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); DCHECK_EQ(filtered_device_mask.size(), @@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts( } std::vector filtered_device_mask; - ctx->context->FilterDevicesForRemoteWorkers( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); DCHECK_EQ(filtered_device_mask.size(), cluster_device_count); @@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // New server created for new server_def. Unused if updating server_def. std::unique_ptr new_server; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server; if (reset_context) { LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); @@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context{new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), - /*device_mgr_owned*/ true, r, - tensorflow::GetDefaultCustomKernelCreator())}; + return new TFE_Context{std::make_unique( + new tensorflow::EagerContext( + opts->session_options.options, + static_cast( + opts->device_placement_policy), + static_cast( + opts->mirroring_policy), + opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), + /*device_mgr_owned*/ true, r, + tensorflow::GetDefaultCustomKernelCreator()))}; } TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, @@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr); - return new TFE_Context{new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr, - /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())}; + return new TFE_Context{std::make_unique( + new tensorflow::EagerContext( + opts->session_options.options, + static_cast( + opts->device_placement_policy), + static_cast( + opts->mirroring_policy), + opts->async, opts->lazy_remote_inputs_copy, device_mgr, + /*device_mgr_owned*/ false, r, + tensorflow::GetDefaultCustomKernelCreator()))}; } void TFE_DeleteContext(TFE_Context* ctx) { // context->RefCountIsOne() should be true here. // TODO(iga): Remove EagerContext refcounting. - ctx->context->Unref(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->Unref(); delete ctx; } @@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { } void TFE_ContextClearCaches(TFE_Context* ctx) { - ctx->context->ClearCachesAndThreadExecutors(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->ClearCachesAndThreadExecutors(); } // Set server_def on the context, possibly updating it. @@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, device_filters[i] = tdf.second.device_filters(i); } const string remote_worker = remote_prefix + std::to_string(task_index); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = - ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters); + context->SetRemoteDeviceFilters(remote_worker, device_filters); } } } @@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); #else // !defined(IS_MOBILE_PLATFORM) tensorflow::ServerDef server_def; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); if (!server_def.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( "Invalid tensorflow.ServerDef protocol buffer"); return; - } else if (ctx->context->GetContextId() == + } else if (context->GetContextId() == tensorflow::EagerContext::kInvalidContextId) { status->status = tensorflow::errors::InvalidArgument( "Trying to update a context with invalid context id."); @@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); return false; #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server = static_cast(context->GetServer()); @@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - status->status = ctx->context->SyncExecutors(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->SyncExecutors(); #endif // !IS_MOBILE_PLATFORM } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context->SetThreadLocalDevicePlacementPolicy( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -887,15 +909,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); return static_cast( - ctx->context->GetDevicePlacementPolicy()); + context->GetDevicePlacementPolicy()); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return TFE_TensorHandle::CreateLocalHandle(tensor, status); + + return new TFE_TensorHandle{ + std::make_unique( + tensorflow::TensorHandle::CreateLocalHandle(tensor))}; } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { @@ -1050,10 +1077,12 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( } return new TFE_TensorHandle{ - std::unique_ptr(h->handle->Copy())}; + std::unique_ptr( + h->handle->Copy())}; } -AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() { +tensorflow::AbstractTensorHandleInterface* +tensorflow::TensorHandleInterface::Copy() { handle_->Ref(); return new TensorHandleInterface(handle_); } @@ -1069,13 +1098,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } - return h->handle->Resolve(&status->status); + std::unique_ptr t = + h->handle->Resolve(&status->status); + if (t == nullptr) { + return nullptr; + } + + tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t); + return tensorflow::TF_TensorFromTensor(tensor, &status->status); } -TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) { +std::unique_ptr +tensorflow::TensorHandleInterface::Resolve(Status* status) { if (!IsValid(status)) { return nullptr; } + if (VariantDeviceIsCustom(handle_->device())) { tensorflow::CustomDevice* custom_device = absl::get(handle_->device()); @@ -1104,7 +1142,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) { h_cpu->Unref(); return nullptr; } - TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); + auto retval = std::make_unique(*t); h_cpu->Unref(); return retval; } else { @@ -1131,7 +1169,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) { if (!status->ok()) return nullptr; } } - return tensorflow::TF_TensorFromTensor(tensor, status); + return std::make_unique(std::move(tensor)); } } @@ -1142,8 +1180,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::down_cast(h->handle.get()) - ->Handle(); + tensorflow::TensorHandleFromInterface(h->handle); if (VariantDeviceIsCustom(handle->device())) { const tensorflow::Tensor* t; status->status = handle->Tensor(&t); @@ -1178,7 +1215,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( void (*deallocator)(void* data, size_t len, void* arg), void* deallocator_arg, TF_Status* status) { tensorflow::Device* device = nullptr; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { @@ -1203,19 +1241,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::Tensor t(static_cast(dtype), tensorflow::TensorShape(dimvec), buf); buf->Unref(); - tensorflow::TensorHandle* ret_handle; if (custom_device == nullptr) { - status->status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), device, device, context, &ret_handle); + return new TFE_TensorHandle{ + std::make_unique( + tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device, + device, context))}; } else { - status->status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), custom_device, context, &ret_handle); + return new TFE_TensorHandle{ + std::make_unique( + tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), custom_device, context))}; } - if (!status->status.ok()) { - return nullptr; - } - return new TFE_TensorHandle{ - std::make_unique(ret_handle)}; } // This function will block till the operation that produces `h` has @@ -1229,9 +1265,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, return 0; } tensorflow::TensorHandle* handle = - tensorflow::down_cast(h->handle.get()) - ->Handle(); - + tensorflow::TensorHandleFromInterface(h->handle); if (handle->IsRemote()) { status->status = tensorflow::errors::InvalidArgument( "TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor " @@ -1248,8 +1282,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - std::unique_ptr new_op( - new TFE_Op{std::make_unique(ctx)}); + std::unique_ptr new_op(new TFE_Op{ctx->context->CreateOperation()}); status->status = new_op->operation->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { new_op.reset(); @@ -1285,8 +1318,8 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { - absl::FixedArray> handles( - num_inputs); + absl::FixedArray> + handles(num_inputs); for (int i = 0; i < num_inputs; ++i) { handles[i].reset(inputs[i]->handle->Copy()); } @@ -1383,7 +1416,10 @@ void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, TF_Status* status) { - status->status = op->operation->SetAttrTensor(attr_name, tensor); + tensorflow::Tensor t; + status->status = TF_TensorToTensor(tensor, &t); + status->status = op->operation->SetAttrTensor( + attr_name, std::make_unique(t)); } void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, @@ -1480,8 +1516,8 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - absl::FixedArray> handles( - *num_retvals); + absl::FixedArray> + handles(*num_retvals); status->status = op->operation->Execute(&handles, num_retvals); if (!status->status.ok()) { return; @@ -1497,17 +1533,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TF_Status* status) { tensorflow::TensorHandle* handle = nullptr; tensorflow::Device* device; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->FindDeviceFromName(device_name, &device); if (!status->status.ok()) { tensorflow::CustomDevice* dev; status->status = context->FindCustomDeviceFromName(device_name, &dev); if (status->status.ok()) { status->status = dev->CopyTensorToDevice( - tensorflow::down_cast( - h->handle.get()) - ->Handle(), - &handle); + tensorflow::TensorHandleFromInterface(h->handle), &handle); if (status->status.ok()) { return new TFE_TensorHandle{ std::make_unique(handle)}; @@ -1524,10 +1558,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, status->status = context->FindCustomDeviceFromName(handle_device_name, &dev); if (status->status.ok()) { status->status = dev->CopyTensorFromDevice( - tensorflow::down_cast( - h->handle.get()) - ->Handle(), - device_name, &handle); + tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle); if (status->status.ok()) { return new TFE_TensorHandle{ std::make_unique(handle)}; @@ -1537,9 +1568,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // Handle regular case. status->status = tensorflow::EagerCopyToDevice( - tensorflow::down_cast(h->handle.get()) - ->Handle(), - context, &context->Executor(), device, false, &handle); + tensorflow::TensorHandleFromInterface(h->handle), context, + &context->Executor(), device, false, &handle); if (status->status.ok()) { return new TFE_TensorHandle{ std::make_unique(handle)}; @@ -1556,41 +1586,56 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context->AddFunctionDef(function_def); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context->AddFunctionDef(function->fdef); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->AddFunctionDef(function->fdef); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, TF_Status* status) { - status->status = ctx->context->RemoveFunction(name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - return ctx->context->FindFunctionDef(name) != nullptr; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(true); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(false); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(false); } } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status) { - return TFE_TensorHandle::CreateLocalHandle(t, status); + return new TFE_TensorHandle{ + std::make_unique( + tensorflow::TensorHandle::CreateLocalHandle(t))}; } void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->Executor().WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*context->MetadataMu()); @@ -1611,21 +1656,27 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace -void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } +void TFE_ContextStartStep(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->StartStep(); +} -void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } +void TFE_ContextEndStep(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->EndStep(); +} void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { - auto operation = tensorflow::down_cast( - op->operation.get()); - *attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str()); + tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); + *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str()); } void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { tensorflow::AttrValueMap m; attrs->attributes->FillAttrValueMap(&m); - auto operation = tensorflow::down_cast( - op->operation.get()); + tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); tensorflow::AttrBuilder* destination = operation->MutableAttrs(); for (auto attribute : m) { destination->Set(attribute.first, attribute.second); @@ -1721,9 +1772,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_); if (!status.status.ok()) return status.status; - *result = tensorflow::down_cast( - result_handle->handle.get()) - ->Handle(); + *result = tensorflow::TensorHandleFromInterface(result_handle->handle); (*result)->Ref(); delete result_handle; return status.status; @@ -1740,9 +1789,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( context_, &tensor_handle, target_device_name.c_str(), &status, info_); if (!status.status.ok()) return status.status; - *result = tensorflow::down_cast( - result_handle->handle.get()) - ->Handle(); + *result = tensorflow::TensorHandleFromInterface(result_handle->handle); (*result)->Ref(); delete result_handle; return status.status; @@ -1766,9 +1813,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { &attributes, num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = tensorflow::down_cast( - outputs[i]->handle.get()) - ->Handle(); + retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle); retvals[i]->Ref(); delete outputs[i]; } @@ -1793,6 +1838,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, TF_Status* status) { auto custom_device = std::make_unique(ctx, device, device_info, device_name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = - ctx->context->RegisterCustomDevice(device_name, std::move(custom_device)); + context->RegisterCustomDevice(device_name, std::move(custom_device)); } diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 2d6dd21e12b..50f31fae3f2 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -54,36 +54,32 @@ extern "C" { TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TFE_TensorHandle* h, TF_Status* status) { - return h->handle->TensorDebugInfo(&status->status); -} - -TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo( - Status* status) { + tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle); const tensorflow::Tensor* tensor; - *status = handle_->Tensor(&tensor); - if (!status->ok()) { + status->status = handle->Tensor(&tensor); + if (!status->status.ok()) { return nullptr; } #ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Device* device = absl::get(handle_->device()); + auto* device = absl::get(handle->device()); // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. - tensorflow::XlaDevice* xla_device = - dynamic_cast(device); + auto* xla_device = dynamic_cast(device); if (xla_device != nullptr) { tensorflow::XlaDevice::PaddedShapeFn shape_fn = xla_device->metadata().padded_shape_fn(); xla::Shape padded_shape; - *status = shape_fn(*tensor, &padded_shape); - if (!status->ok()) { + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { return nullptr; } if (VLOG_IS_ON(3)) { - std::vector shape_to_log = TensorShapeAsVector(*handle_, status); - if (!status->ok()) { + std::vector shape_to_log = + TensorShapeAsVector(*handle, &status->status); + if (!status->status.ok()) { // Ignore the status here as we are simply logging. - *status = tensorflow::Status::OK(); + status->status = tensorflow::Status::OK(); } else { VLOG(3) << "Fully padded shape of [" << absl::StrJoin(shape_to_log, ", ") << "] is " @@ -96,7 +92,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo( // Currently, the only case of XlaTensor containing a tuple shape is to // represent 64 bit ints, doubles, and complex numbers (we don't support // 64bit complex numbers). - *status = tensorflow::errors::InvalidArgument( + status->status = tensorflow::errors::InvalidArgument( "XlaTensors should only contain tuples of size 2. Shape: ", padded_shape.DebugString()); return nullptr; @@ -108,13 +104,13 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo( const xla::Shape& shape1 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); if (shape0.IsTuple() || shape1.IsTuple()) { - *status = tensorflow::errors::InvalidArgument( + status->status = tensorflow::errors::InvalidArgument( "XlaTensors should not contain nested tuples. Shape: ", padded_shape.DebugString()); return nullptr; } if (!xla::ShapeUtil::Equal(shape0, shape1)) { - *status = tensorflow::errors::InvalidArgument( + status->status = tensorflow::errors::InvalidArgument( "Subshapes of XlaTensors should be the same. Shape: ", padded_shape.DebugString()); return nullptr; @@ -139,15 +135,15 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo( dev_dims.push_back(padded_shape.dimensions(dim_index)); } } - *status = tensorflow::Status::OK(); + status->status = tensorflow::Status::OK(); return new TFE_TensorDebugInfo(dev_dims); } #endif // TENSORFLOW_EAGER_USE_XLA // If the tensor is not an XLA tensor, the device shape is // the same as regular tensor shape. - std::vector dev_dims = TensorShapeAsVector(*handle_, status); - if (!status->ok()) { + std::vector dev_dims = TensorShapeAsVector(*handle, &status->status); + if (!status->status.ok()) { return nullptr; } return new TFE_TensorDebugInfo(dev_dims); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index afa36fe1210..4d01a066642 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -31,6 +31,7 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { + op_to_reset->operation->Clear(); status->status = op_to_reset->operation->Reset(op_or_function_name, raw_device_name); } else { @@ -40,11 +41,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(true); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(false); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(false); } void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, @@ -474,7 +479,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, void TFE_ContextSetThreadLocalMirroringPolicy( TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { - ctx->context->SetThreadLocalMirroringPolicy( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetThreadLocalMirroringPolicy( static_cast(policy)); } @@ -483,8 +490,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context* ctx) { - return static_cast( - ctx->context->GetMirroringPolicy()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return static_cast(context->GetMirroringPolicy()); } void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, @@ -492,6 +500,10 @@ void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, options->lazy_remote_inputs_copy = lazy_copy; } +void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) { + options->use_tfrt = use_tfrt; +} + TFE_CancellationManager* TFE_NewCancellationManager() { return new TFE_CancellationManager; } @@ -514,7 +526,11 @@ void TFE_DeleteCancellationManager( void TFE_OpSetCancellationManager(TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status) { - status->status = op->operation->SetCancellationManager(cancellation_manager); + tensorflow::EagerOperation* operation = + tensorflow::OperationFromInterface(op->operation); + operation->SetCancellationManager( + &cancellation_manager->cancellation_manager); + status->status = tensorflow::Status::OK(); } TFE_Executor* TFE_NewExecutor(bool is_async) { @@ -537,16 +553,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { } void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { - ctx->context->SetExecutorForThread(executor->executor()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - return new TFE_Executor(&ctx->context->Executor()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return new TFE_Executor(&context->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( - ctx->context->HostCPU()->parsed_name()); + context->HostCPU()->parsed_name()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); void* data = tensorflow::port::Malloc(str.length()); str.copy(static_cast(data), str.length(), 0); @@ -565,7 +587,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { - auto* function_def = ctx->context->FindFunctionDef(function_name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + auto* function_def = context->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( "Unable to find FunctionDef with name: ", function_name); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index c24735963d6..5f9190af79a 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -296,6 +296,10 @@ TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( TFE_ContextOptions*, bool lazy_copy); +// Sets whether to use TFRT +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, + bool use_tfrt); + // ----------------------------------------------------------------------------- // Cancellation APIs. diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 8f333ab3aef..cf71863e124 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -455,6 +455,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) { TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy. TFE_DeleteTensorHandle(on_host); } + TF_DeleteDeviceList(devices); TF_DeleteTensor(m_data); TFE_DeleteTensorHandle(m); TFE_DeleteContext(ctx); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 05b0a143025..754fea1aad5 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -59,25 +60,16 @@ struct TFE_ContextOptions { TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE}; // If true, lazily copy the remote inputs of a function to the target devices. bool lazy_remote_inputs_copy = true; + // If true, use TFRT backend + bool use_tfrt = false; }; struct TFE_Context { - tensorflow::EagerContext* context; + std::unique_ptr context; }; struct TFE_TensorHandle { - static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t, - TF_Status* s) { - tensorflow::TensorHandle* handle; - s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle); - if (!s->status.ok()) { - return nullptr; - } - return new TFE_TensorHandle{ - std::make_unique(handle)}; - } - - std::unique_ptr handle; + std::unique_ptr handle; }; struct TFE_TensorDebugInfo { @@ -89,7 +81,7 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - std::unique_ptr operation; + std::unique_ptr operation; }; struct TFE_MonitoringCounterCell { diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index eb6b234e3df..6a3f54f56dd 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -184,9 +184,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { // TODO(gjn): Add support for waiting on async local mirrors if (!async) { - auto remote_arg = tensorflow::down_cast( - h1_task2->handle.get()) - ->Handle(); + auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); auto op = tensorflow::down_cast( matmul->operation.get()); // The input handles should never change since they have been mirrored. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 29dba253fee..4664af87fa3 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -409,13 +409,8 @@ void TensorHandleSilentCopy(bool async, ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Validate if the input was replaced with a different TensorHandle - auto arg0 = tensorflow::down_cast( - hcpu->handle.get()) - ->Handle(); - auto arg1 = tensorflow::down_cast( - hgpu->handle.get()) - ->Handle(); - + auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle); + auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle); auto op = tensorflow::down_cast( matmul->operation.get()); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc new file mode 100644 index 00000000000..e501f70a0f2 --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -0,0 +1,261 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_unified_experimental.h" + +#include "absl/types/variant.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/mutex.h" + +using tensorflow::string; + +// ============================================================================= +// Unified Execution APIs for Eager and tracing backends. +// ============================================================================= + +typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, + TF_OutputList* o, TF_ExecutionContext* ctx, + TF_Status* s); +struct TF_ExecutionContext { + explicit TF_ExecutionContext() {} + absl::variant ctx; + ExecuteOperation execution_callback; +}; + +struct TF_AbstractTensor { + absl::variant t; +}; + +struct TF_AbstractOp { + string op_type; + string op_name; +}; + +TF_ExecutionContext* TF_NewExecutionContext() { + return new TF_ExecutionContext(); +} + +void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; } + +TF_AbstractOp* TF_NewAbstractOp() { + TF_AbstractOp* op = new TF_AbstractOp; + return op; +} + +void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; } + +TF_AbstractTensor* TF_NewAbstractTensor() { + TF_AbstractTensor* t = new TF_AbstractTensor; + return t; +} + +void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; } + +struct TF_GraphContext { + TF_Graph* graph; + // TODO(srbs): Handle captures. +}; + +TF_GraphContext* TF_NewGraphContext(TF_Graph* g) { + auto ctx = new TF_GraphContext; + ctx->graph = g; + return ctx; +} + +void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; } + +struct TF_GraphTensor { + TF_Output output; + TF_GraphContext* ctx; +}; +TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output, + TF_Status* s) { + TF_GraphTensor* t = new TF_GraphTensor; + t->output = output; + t->ctx = ctx; + return t; +} +TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) { + return t->output; +} +void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; } +void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t, + TF_Status* s) { + at->t = t; +} +TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, + TF_Status* s) { + if (!absl::holds_alternative(at->t)) { + string msg = absl::StrCat("Not an eager tensor handle.", + reinterpret_cast(at)); + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; + } + return absl::get(at->t); +} +void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t, + TF_Status* s) { + at->t = t; +} +TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at, + TF_Status* s) { + if (!absl::holds_alternative(at->t)) { + string msg = absl::StrCat("Not an graph tensor handle."); + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; + } + return absl::get(at->t); +} + +bool IsEagerTensor(const TF_AbstractTensor* const t) { + return absl::holds_alternative(t->t); +} + +struct TF_OutputList { + std::vector outputs; + int expected_num_outputs = -1; +}; + +TF_OutputList* TF_NewOutputList() { return new TF_OutputList; } +void TF_DeleteOutputList(TF_OutputList* o) { delete o; } +void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, + TF_Status* s) { + o->expected_num_outputs = num_outputs; +} +int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); } +TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { + return o->outputs[i]; +} + +void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, TF_OutputList* o, + TF_ExecutionContext* ctx, TF_Status* s) { + auto* tfe_op = + TFE_NewOp(absl::get(ctx->ctx), op->op_type.c_str(), s); + if (TF_GetCode(s) != TF_OK) return; + for (int i = 0; i < num_inputs; ++i) { + if (!IsEagerTensor(inputs[i])) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor."); + return; + } + TFE_OpAddInput(tfe_op, absl::get(inputs[i]->t), s); + if (TF_GetCode(s) != TF_OK) return; + } + if (o->expected_num_outputs == -1) { + string msg = + "The number of outputs must be provided in eager mode. Use " + "TF_OutputListSetNumOutputs."; + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return; + } + tensorflow::gtl::InlinedVector retvals; + int num_retvals = o->expected_num_outputs; + retvals.resize(num_retvals); + TFE_Execute(tfe_op, retvals.data(), &num_retvals, s); + TFE_DeleteOp(tfe_op); + if (TF_GetCode(s) != TF_OK) { + return; + } + o->outputs.clear(); + o->outputs.reserve(num_retvals); + for (int i = 0; i < num_retvals; ++i) { + auto* t = TF_NewAbstractTensor(); + t->t = retvals[i]; + o->outputs.push_back(t); + } +} + +TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) { + return absl::get(t->t)->ctx; +} + +void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, TF_OutputList* o, + TF_ExecutionContext* ctx, TF_Status* s) { + TF_GraphContext* graph_ctx = absl::get(ctx->ctx); + TF_Graph* g = graph_ctx->graph; + auto* tf_opdesc = + TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str()); + for (int i = 0; i < num_inputs; ++i) { + auto* input = inputs[i]; + if (IsEagerTensor(input)) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Capturing eager tensors is not supported yet."); + return; + } else { + if (GetGraphContext(input) != graph_ctx) { + TF_SetStatus( + s, TF_INVALID_ARGUMENT, + "Capturing tensors from other graphs is not supported yet."); + return; + } + TF_AddInput(tf_opdesc, absl::get(input->t)->output); + } + } + auto* operation = TF_FinishOperation(tf_opdesc, s); + if (TF_GetCode(s) != TF_OK) return; + int num_outputs = TF_OperationNumOutputs(operation); + o->outputs.clear(); + o->outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + auto* t = TF_NewAbstractTensor(); + TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s); + if (TF_GetCode(s) != TF_OK) { + return; + } + t->t = output_t; + o->outputs.push_back(t); + } +} + +void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context, + TFE_Context* eager_context, + TF_Status* s) { + context->ctx = eager_context; + context->execution_callback = &ExecuteOperationEager; +} + +void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context, + TF_GraphContext* graph_context, + TF_Status* s) { + context->ctx = graph_context; + context->execution_callback = &ExecuteOperationGraph; +} + +void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, + TF_Status* s) { + op->op_type = op_type; +} + +void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, + TF_Status* s) { + op->op_name = op_name; +} + +void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, TF_OutputList* o, + TF_ExecutionContext* ctx, TF_Status* s) { + ctx->execution_callback(op, num_inputs, inputs, o, ctx, s); +} diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h new file mode 100644 index 00000000000..6346ceaf26e --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// Unified Execution APIs for Eager and tracing backends. +// ============================================================================= + +// ----------------------------------------------------------------------------- +// Core APIs +// ----------------------------------------------------------------------------- + +// A TF_ExecutionContext stores knowledge about how to execute an operation. +// E.g. it could know whether we're in eager mode or in graph mode, keeps track +// of gradient tapes, etc. +typedef struct TF_ExecutionContext TF_ExecutionContext; +// A TF_AbstractTensor is an input to an operation. E.g. it could be a union +// type of eager and graph tensors. +typedef struct TF_AbstractTensor TF_AbstractTensor; +// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this +// could contain the op type and other attributes. +typedef struct TF_AbstractOp TF_AbstractOp; + +TF_ExecutionContext* TF_NewExecutionContext(); +void TF_DeleteExecutionContext(TF_ExecutionContext*); + +TF_AbstractOp* TF_NewAbstractOp(); +void TF_DeleteAbstractOp(TF_AbstractOp*); + +TF_AbstractTensor* TF_NewAbstractTensor(); +void TF_DeleteAbstractTensor(TF_AbstractTensor*); + +// ----------------------------------------------------------------------------- +// APIs for Eager and graph modes +// ----------------------------------------------------------------------------- + +// Keeps track of the current graph and other state e.g. captures etc. +typedef struct TF_GraphContext TF_GraphContext; +TF_GraphContext* TF_NewGraphContext(TF_Graph*); +void TF_DeleteGraphContext(TF_GraphContext*); + +// `eager_context` must outlive `context`. +void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context, + TFE_Context* eager_context, TF_Status*); +// `graph_context` must outlive `context`. +void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context, + TF_GraphContext* graph_context, + TF_Status*); + +// TODO(srbs): Add APIs for specifying attrs etc. +// `op_type` must outlive `op`. +void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, + TF_Status* s); +// `op_name` must outlive `op`. +void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, + TF_Status* s); + +// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well. +typedef struct TF_GraphTensor TF_GraphTensor; +TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t, + TF_Status* s); +TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s); +void TF_DeleteGraphTensor(TF_GraphTensor* t); + +// `t` must outlive `at`. +void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t, + TF_Status* s); +TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, + TF_Status* s); + +// `t` must outlive `at`. +void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t, + TF_Status* s); +TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at, + TF_Status* s); + +// TF_OutputList just lets us not specify the number of outputs of an operation +// beforehand. This forces a memory allocation in the runtime, which is bad, but +// it allows for generic code. +typedef struct TF_OutputList TF_OutputList; +TF_OutputList* TF_NewOutputList(); +void TF_DeleteOutputList(TF_OutputList* o); +void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); +int TF_OutputListNumOutputs(TF_OutputList* o); +TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); + +// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe +// capture some inputs and then add a node in the graph, and after +// execution/node creation it'll go and record things that happened in any tape +// which happens to be active. +void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, TF_OutputList* o, + TF_ExecutionContext* ctx, TF_Status* s); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc new file mode 100644 index 00000000000..58b4237e119 --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -0,0 +1,204 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_unified_experimental.h" + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using tensorflow::string; + +namespace tensorflow { +namespace { + +TEST(UnifedCAPI, TestBasicEager) { + TF_ExecutionContext* ctx = TF_NewExecutionContext(); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* eager_ctx = TFE_NewContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + // Enter the eager context. + TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract input tensor. + TFE_TensorHandle* t = TestScalarTensorHandle(2.0f); + TF_AbstractTensor* at = TF_NewAbstractTensor(); + TF_AbstractTensorSetEagerTensor(at, t, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. + auto* op = TF_NewAbstractOp(); + TF_AbstractOpSetOpType(op, "Add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {at, at}; + TF_OutputList* o = TF_NewOutputList(); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute. + TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(at); + TFE_DeleteTensorHandle(t); + + // Verify the results. + ASSERT_EQ(1, TF_OutputListNumOutputs(o)); + TF_AbstractTensor* result = TF_OutputListGet(o, 0); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + float* result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 4.0); + + TF_DeleteTensor(result_tensor); + TF_DeleteAbstractTensor(result); + TFE_DeleteTensorHandle(result_t); + TF_DeleteOutputList(o); + TFE_DeleteContext(eager_ctx); + TF_DeleteExecutionContext(ctx); +} + +TEST(UnifedCAPI, TestBasicGraph) { + TF_ExecutionContext* ctx = TF_NewExecutionContext(); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + // Enter a graph context. + TF_Graph* g = TF_NewGraph(); + TF_GraphContext* graph_context = TF_NewGraphContext(g); + TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Add a placeholder to the graph. + auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder"); + TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT); + auto* operation = TF_FinishOperation(placeholder_op, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Output placeholder_t = {operation, 0}; + TF_GraphTensor* graph_t = + TF_NewGraphTensor(graph_context, placeholder_t, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractTensor* t = TF_NewAbstractTensor(); + TF_AbstractTensorSetGraphTensor(t, graph_t, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. + auto* op = TF_NewAbstractOp(); + TF_AbstractOpSetOpType(op, "Add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(op, "my_add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {t, t}; + TF_OutputList* o = TF_NewOutputList(); + + // Execute. + TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(t); + TF_DeleteGraphTensor(graph_t); + + TF_AbstractTensor* result = TF_OutputListGet(o, 0); + TF_GraphTensor* result_graph_tensor = + TF_AbstractTensorGetGraphTensor(result, status.get()); + TF_DeleteAbstractTensor(result); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Output result_output = + TF_GraphTensorToOutput(result_graph_tensor, status.get()); + TF_DeleteGraphTensor(result_graph_tensor); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + string fn_name = "double"; + TF_Function* f = TF_GraphToFunction( + g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output, + nullptr, nullptr, fn_name.c_str(), status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an eager context to run the function. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* eager_ctx = TFE_NewContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + // Build the abstract op to run the function. + TFE_ContextAddFunction(eager_ctx, f, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOp* fn_op = TF_NewAbstractOp(); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract input tensor. + TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f); + TF_AbstractTensor* input_t = TF_NewAbstractTensor(); + TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Enter the eager context. + TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + ASSERT_EQ(1, TF_OutputListNumOutputs(o)); + TF_AbstractTensor* final_result = TF_OutputListGet(o, 0); + TFE_TensorHandle* final = + TF_AbstractTensorGetEagerTensor(final_result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + float* f_value = static_cast(TF_TensorData(f_t)); + ASSERT_EQ(*f_value, 4.0); + + TF_DeleteOutputList(o); + TF_DeleteAbstractOp(fn_op); + TF_DeleteAbstractTensor(input_t); + TFE_DeleteTensorHandle(input_eager); + TF_DeleteAbstractTensor(final_result); + TFE_DeleteTensorHandle(final); + TF_DeleteTensor(f_t); + TF_DeleteFunction(f); + + TF_DeleteGraphContext(graph_context); + TF_DeleteGraph(g); + TFE_DeleteContext(eager_ctx); + TF_DeleteExecutionContext(ctx); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.cc b/tensorflow/c/eager/context_interface.cc new file mode 100644 index 00000000000..f190d7aeaf3 --- /dev/null +++ b/tensorflow/c/eager/context_interface.cc @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/context_interface.h" + +#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" + +namespace tensorflow { + +std::unique_ptr ContextInterface::CreateInt64Scalar( + int64 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateUint64Scalar( + uint64 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateInt32Scalar( + int32 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateFloatScalar( + float value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateDoubleScalar( + double value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateHalfScalar( + Eigen::half value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateStringScalar( + tstring value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr +ContextInterface::CreateComplex128Scalar(complex128 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateBoolScalar( + bool value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateInt64Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_INT64, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateUint64Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_UINT64, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateInt32Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_INT32, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateFloatTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_FLOAT, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateDoubleTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_DOUBLE, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateHalfTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_HALF, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateStringTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_STRING, TensorShape(dim_sizes))); +} + +std::unique_ptr +ContextInterface::CreateComplex128Tensor(absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_COMPLEX128, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateBoolTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_BOOL, TensorShape(dim_sizes))); +} + +std::unique_ptr +ContextInterface::CreateLocalHandle( + const std::unique_ptr t) { + Tensor tensor = tensorflow::down_cast(t.get())->Tensor(); + return std::make_unique( + TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(), + /*op_device=*/nullptr, ctx_)); +} + +std::unique_ptr +ContextInterface::CreateOperation() { + return std::make_unique(ctx_); +} + +void ContextInterface::ListDevices( + std::vector* devices) { + ctx_->ListDevices(devices); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h new file mode 100644 index 00000000000..665651cc873 --- /dev/null +++ b/tensorflow/c/eager/context_interface.h @@ -0,0 +1,155 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ + +#include + +#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/tstring.h" + +namespace tensorflow { + +// Abstract interface to a context. +// +// A context is responsible for creating key objects such as Tensors, +// TensorHandles & Operations. +class AbstractContextInterface { + public: + virtual ~AbstractContextInterface() {} + + // Scalar creation functions + virtual std::unique_ptr CreateInt64Scalar( + int64 value) = 0; + virtual std::unique_ptr CreateUint64Scalar( + uint64 value) = 0; + virtual std::unique_ptr CreateInt32Scalar( + int32 value) = 0; + virtual std::unique_ptr CreateFloatScalar( + float value) = 0; + virtual std::unique_ptr CreateDoubleScalar( + double value) = 0; + virtual std::unique_ptr CreateHalfScalar( + Eigen::half value) = 0; + virtual std::unique_ptr CreateStringScalar( + tstring value) = 0; + virtual std::unique_ptr CreateComplex128Scalar( + complex128 value) = 0; + virtual std::unique_ptr CreateBoolScalar( + bool value) = 0; + + // Tensor creation functions + virtual std::unique_ptr CreateInt64Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateUint64Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateInt32Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateFloatTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateDoubleTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateHalfTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateStringTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateComplex128Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateBoolTensor( + absl::Span dim_sizes) = 0; + + // Create a handle to wrap and manage a Tensor + virtual std::unique_ptr CreateLocalHandle( + const std::unique_ptr t) = 0; + + // Create an operation to perform op execution + virtual std::unique_ptr CreateOperation() = 0; + + // List attributes of available devices + virtual void ListDevices(std::vector* devices) = 0; +}; + +// TODO(gjn): Try to move these all to EagerContext and make it implement +// AbstractContextInterface. Currently, this is not so straightforward because +// of various BUILD file dependencies. +class ContextInterface : public AbstractContextInterface { + public: + explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {} + ~ContextInterface() override {} + + std::unique_ptr CreateInt64Scalar( + int64 value) override; + std::unique_ptr CreateUint64Scalar( + uint64 value) override; + std::unique_ptr CreateInt32Scalar( + int32 value) override; + std::unique_ptr CreateFloatScalar( + float value) override; + std::unique_ptr CreateDoubleScalar( + double value) override; + std::unique_ptr CreateHalfScalar( + Eigen::half value) override; + std::unique_ptr CreateStringScalar( + tensorflow::tstring value) override; + std::unique_ptr CreateComplex128Scalar( + tensorflow::complex128 value) override; + std::unique_ptr CreateBoolScalar( + bool value) override; + + std::unique_ptr CreateInt64Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateUint64Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateInt32Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateFloatTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateDoubleTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateHalfTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateStringTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateComplex128Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateBoolTensor( + absl::Span dim_sizes) override; + + std::unique_ptr CreateLocalHandle( + const std::unique_ptr t) override; + std::unique_ptr CreateOperation() override; + + void ListDevices(std::vector* devices) override; + + // For runtime specific APIs, provide ability to get the underlying context. + EagerContext* Context() const { return ctx_; } + + private: + EagerContext* ctx_; +}; + +inline EagerContext* ContextFromInterface( + const std::unique_ptr& context) { + return down_cast(context.get())->Context(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index fee2154c8dc..1dc3a08afe0 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_reference.h" -#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -47,9 +46,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::down_cast(h->handle.get()) - ->Handle(); - + tensorflow::TensorHandleFromInterface(h->handle); if (handle->IsRemote()) { status->status = tensorflow::errors::InvalidArgument( "DLPack doesn't support remote tensor"); @@ -289,9 +286,8 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { return static_cast(dlm_tensor); } -TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); +TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, + TFE_Context* ctx) { DLManagedTensor* dlmt = static_cast(dlm); DLTensor* dl_tensor = &dlmt->dl_tensor; absl::optional device_name = diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 4177af1a6e7..8c85dee62f7 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, // Converts DLPack (DLManagedTensor*) to eager tensor handle. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, - TF_Status* status); + TF_Status* status, + TFE_Context* ctx); // Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc index 5703d3231bd..136fdef2de5 100644 --- a/tensorflow/c/eager/operation_interface.cc +++ b/tensorflow/c/eager/operation_interface.cc @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { -OperationInterface::OperationInterface(TFE_Context* ctx) - : operation_(ctx->context) {} +OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {} const string& OperationInterface::DeviceName() const { absl::variant variant_device = @@ -99,9 +98,8 @@ Status OperationInterface::SetAttrFunction( AttrValue attr_value; NameAttrList* func = attr_value.mutable_func(); func->set_name(value->Name()); - OperationInterface* value_operation = - tensorflow::down_cast(value.get()); - value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr()); + EagerOperation* value_operation = OperationFromInterface(value); + value_operation->Attrs().FillAttrValueMap(func->mutable_attr()); operation_.MutableAttrs()->Set(attr_name, attr_value); return Status::OK(); } @@ -116,10 +114,9 @@ Status OperationInterface::SetAttrFunctionName(const char* attr_name, return Status::OK(); } -Status OperationInterface::SetAttrTensor(const char* attr_name, - TF_Tensor* tensor) { - Tensor t; - TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t)); +Status OperationInterface::SetAttrTensor( + const char* attr_name, std::unique_ptr tensor) { + Tensor t = TensorFromInterface(tensor); operation_.MutableAttrs()->Set(attr_name, t); return Status::OK(); } @@ -209,11 +206,10 @@ Status OperationInterface::SetAttrFunctionList(const char* attr_name, int num_values) { std::unique_ptr funcs(new NameAttrList[num_values]); for (int i = 0; i < num_values; i++) { - auto value_operation = - tensorflow::down_cast(value[i]->operation.get()); - funcs[i].set_name(value_operation->operation_.Name()); - value_operation->operation_.Attrs().FillAttrValueMap( - funcs[i].mutable_attr()); + EagerOperation* value_operation = + OperationFromInterface(value[i]->operation); + funcs[i].set_name(value_operation->Name()); + value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr()); } operation_.MutableAttrs()->Set( attr_name, gtl::ArraySlice(funcs.get(), num_values)); @@ -267,8 +263,7 @@ Status OperationInterface::OutputLength(const char* output_name, int* length) { Status OperationInterface::AddInput( const std::unique_ptr& input) { - TensorHandle* h = - tensorflow::down_cast(input.get())->Handle(); + TensorHandle* h = TensorHandleFromInterface(input); operation_.AddInput(h); return operation_.MaybeInferSingleInputAttrs(h); } @@ -277,8 +272,7 @@ Status OperationInterface::AddInputList( const absl::FixedArray>& inputs) { for (auto& input : inputs) { - TensorHandle* h = - tensorflow::down_cast(input.get())->Handle(); + TensorHandle* h = TensorHandleFromInterface(input); operation_.AddInput(h); } return operation_.InferInputListAttrs(inputs.size()); @@ -297,13 +291,6 @@ Status OperationInterface::Execute( return Status::OK(); } -Status OperationInterface::SetCancellationManager( - TFE_CancellationManager* cancellation_manager) { - operation_.SetCancellationManager( - &cancellation_manager->cancellation_manager); - return Status::OK(); -} - Status OperationInterface::SetUseXla(bool enable) { operation_.SetUseXla(enable); return Status::OK(); diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h index 900c5112c08..eb818827a19 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/operation_interface.h @@ -19,9 +19,14 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/tf_datatype.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { // Abstract interface to an operation. class AbstractOperationInterface { @@ -29,90 +34,67 @@ class AbstractOperationInterface { virtual ~AbstractOperationInterface() {} virtual void Clear() = 0; - virtual tensorflow::Status Reset(const char* op, - const char* raw_device_name) = 0; + virtual Status Reset(const char* op, const char* raw_device_name) = 0; - virtual const tensorflow::string& Name() const = 0; - virtual const tensorflow::string& DeviceName() const = 0; - virtual tensorflow::Status SetDeviceName(const char* name) = 0; + virtual const string& Name() const = 0; + virtual const string& DeviceName() const = 0; + virtual Status SetDeviceName(const char* name) = 0; - virtual tensorflow::Status AddInput( + virtual Status AddInput( const std::unique_ptr& input) = 0; - virtual tensorflow::Status AddInputList( + virtual Status AddInputList( const absl::FixedArray>& inputs) = 0; - virtual tensorflow::Status Execute( + virtual Status Execute( absl::FixedArray>* retvals, int* num_retvals) = 0; virtual const tensorflow::OpDef* OpDef() const = 0; - virtual tensorflow::Status SetAttrString(const char* attr_name, - const char* data, size_t length) = 0; - virtual tensorflow::Status SetAttrInt(const char* attr_name, - int64_t value) = 0; - virtual tensorflow::Status SetAttrFloat(const char* attr_name, - float value) = 0; - virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0; - virtual tensorflow::Status SetAttrType(const char* attr_name, - TF_DataType value) = 0; - virtual tensorflow::Status SetAttrShape(const char* attr_name, - const int64_t* dims, - const int num_dims) = 0; - virtual tensorflow::Status SetAttrFunction( + virtual Status SetAttrString(const char* attr_name, const char* data, + size_t length) = 0; + virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0; + virtual Status SetAttrFloat(const char* attr_name, float value) = 0; + virtual Status SetAttrBool(const char* attr_name, bool value) = 0; + virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0; + virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) = 0; + virtual Status SetAttrFunction( const char* attr_name, const std::unique_ptr& value) = 0; - virtual tensorflow::Status SetAttrFunctionName(const char* attr_name, - const char* value, - size_t length) = 0; - virtual tensorflow::Status SetAttrTensor(const char* attr_name, - TF_Tensor* tensor) = 0; - virtual tensorflow::Status SetAttrStringList(const char* attr_name, - const void* const* values, - const size_t* lengths, - int num_values) = 0; - virtual tensorflow::Status SetAttrFloatList(const char* attr_name, - const float* values, - int num_values) = 0; - virtual tensorflow::Status SetAttrIntList(const char* attr_name, - const int64_t* values, - int num_values) = 0; - virtual tensorflow::Status SetAttrTypeList(const char* attr_name, - const TF_DataType* values, - int num_values) = 0; - virtual tensorflow::Status SetAttrBoolList(const char* attr_name, - const unsigned char* values, - int num_values) = 0; - virtual tensorflow::Status SetAttrShapeList(const char* attr_name, - const int64_t** dims, - const int* num_dims, - int num_values) = 0; - virtual tensorflow::Status SetAttrFunctionList(const char* attr_name, - const TFE_Op** value, - int num_values) = 0; + virtual Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) = 0; + virtual Status SetAttrTensor( + const char* attr_name, + std::unique_ptr tensor) = 0; + virtual Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, int num_values) = 0; + virtual Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) = 0; + virtual Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) = 0; + virtual Status SetAttrTypeList(const char* attr_name, + const TF_DataType* values, int num_values) = 0; + virtual Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) = 0; + virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) = 0; + virtual Status SetAttrFunctionList(const char* attr_name, + const TFE_Op** value, int num_values) = 0; - virtual tensorflow::Status InputLength(const char* input_name, - int* length) = 0; - virtual tensorflow::Status OutputLength(const char* output_name, - int* length) = 0; + virtual Status InputLength(const char* input_name, int* length) = 0; + virtual Status OutputLength(const char* output_name, int* length) = 0; // Experimental - virtual tensorflow::Status SetUseXla(bool enable) { - return tensorflow::errors::Unimplemented("SetUseXla not implemented"); - } - virtual tensorflow::Status SetCancellationManager( - TFE_CancellationManager* cancellation_manager) { - return tensorflow::errors::Unimplemented( - "SetCancellationManager not implemented"); - } + virtual Status SetUseXla(bool enable) = 0; }; -namespace tensorflow { - class OpDef; class OperationInterface : public AbstractOperationInterface { public: - explicit OperationInterface(TFE_Context* ctx); + explicit OperationInterface(EagerContext* ctx); ~OperationInterface() override{}; void Clear() override { operation_.Clear(); } @@ -149,7 +131,9 @@ class OperationInterface : public AbstractOperationInterface { const std::unique_ptr& value) override; Status SetAttrFunctionName(const char* attr_name, const char* data, size_t length) override; - Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override; + Status SetAttrTensor( + const char* attr_name, + std::unique_ptr tensor) override; Status SetAttrStringList(const char* attr_name, const void* const* values, const size_t* lengths, int num_values) override; Status SetAttrFloatList(const char* attr_name, const float* values, @@ -169,20 +153,25 @@ class OperationInterface : public AbstractOperationInterface { Status OutputLength(const char* output_name, int* length) override; Status SetUseXla(bool enable) override; - Status SetCancellationManager( - TFE_CancellationManager* cancellation_manager) override; // TODO(gjn): Remove once TFE_InferShapes is removed - const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); } - tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); } + const AttrBuilder& Attrs() const { return operation_.Attrs(); } + AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); } const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; } + EagerOperation* Operation() { return &operation_; } + private: const tensorflow::OpDef* GetOpDef(Status* status); EagerOperation operation_; }; +inline EagerOperation* OperationFromInterface( + const std::unique_ptr& operation) { + return down_cast(operation.get())->Operation(); +} + } // namespace tensorflow #endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/tensor_handle_interface.h index 9008550b2c6..6d73ff33b8b 100644 --- a/tensorflow/c/eager/tensor_handle_interface.h +++ b/tensorflow/c/eager/tensor_handle_interface.h @@ -15,10 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ #define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { // Abstract interface to a TensorHandle. // @@ -34,24 +37,22 @@ class AbstractTensorHandleInterface { virtual ~AbstractTensorHandleInterface() {} // Check if the handle is in a valid initialized state. - virtual bool IsValid(tensorflow::Status* status) const = 0; + virtual bool IsValid(Status* status) const = 0; // Returns tensor dtype. virtual TF_DataType DataType() const = 0; // Returns number of dimensions. - virtual int NumDims(tensorflow::Status* status) const = 0; + virtual int NumDims(Status* status) const = 0; // Returns number of elements across all dimensions. - virtual int64_t NumElements(tensorflow::Status* status) const = 0; + virtual int64_t NumElements(Status* status) const = 0; // Returns size of specified dimension - virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0; + virtual int64_t Dim(int dim_index, Status* status) const = 0; // Returns the device which created the handle. - virtual const char* DeviceName(tensorflow::Status* status) const = 0; + virtual const char* DeviceName(Status* status) const = 0; // Returns the device where the tensor was placed. - virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0; + virtual const char* BackingDeviceName(Status* status) const = 0; // Returns a tensor for the handle. If tensor is remote, it will be copied. - virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0; - // Returns debug information about the tensor. - virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0; + virtual std::unique_ptr Resolve(Status* status) = 0; // Return a copy of the handle. virtual AbstractTensorHandleInterface* Copy() = 0; @@ -65,8 +66,9 @@ class AbstractTensorHandleInterface { virtual void EnableImplicitMirroring() = 0; }; -namespace tensorflow { - +// TODO(gjn): Try to move these all to TensorHandle and make it implement +// AbstractTensorHandleInterface. Currently, this is not so straightforward +// because of various BUILD file dependencies. class TensorHandleInterface : public AbstractTensorHandleInterface { public: explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {} @@ -80,21 +82,24 @@ class TensorHandleInterface : public AbstractTensorHandleInterface { const char* DeviceName(Status* status) const override; const char* BackingDeviceName(Status* status) const override; - TF_Tensor* Resolve(Status* status) override; - TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override; + std::unique_ptr Resolve(Status* status) override; AbstractTensorHandleInterface* Copy() override; void EnableImplicitMirroring() override; - // TODO(gjn): This is not a very generic interface, but is needed for specific - // use cases. + // For runtime specific APIs, provide ability to get the underlying handle. TensorHandle* Handle() { return handle_; } private: TensorHandle* handle_; }; +inline TensorHandle* TensorHandleFromInterface( + const std::unique_ptr& handle) { + return down_cast(handle.get())->Handle(); +} + } // namespace tensorflow #endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ diff --git a/tensorflow/c/tf_status.cc b/tensorflow/c/tf_status.cc index 3144f2c1900..8db826627b4 100644 --- a/tensorflow/c/tf_status.cc +++ b/tensorflow/c/tf_status.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/error.h" +#include "tensorflow/core/platform/status.h" using ::tensorflow::IOError; using ::tensorflow::Status; diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 7e52049445b..ff8085f1229 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_C_TF_STATUS_HELPER_H_ #include "tensorflow/c/tf_status.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/c/tf_status_helper_test.cc b/tensorflow/c/tf_status_helper_test.cc index ab945ccfd00..60780d74b21 100644 --- a/tensorflow/c/tf_status_helper_test.cc +++ b/tensorflow/c/tf_status_helper_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/c/tf_status_internal.h b/tensorflow/c/tf_status_internal.h index 66ca9938f0c..1e0f99819ff 100644 --- a/tensorflow/c/tf_status_internal.h +++ b/tensorflow/c/tf_status_internal.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_ #define TENSORFLOW_C_TF_STATUS_INTERNAL_H_ -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" // Internal structures used by the status C API. These are likely to change // and should not be depended on. diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 4e75beceb3e..03833368102 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { ->ToTensor(dst); } -Status TensorInterface::ToTensor(Tensor* dst) const { +Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const { if (tensor_.dtype() == DT_RESOURCE) { if (tensor_.dims() != 0) { return InvalidArgument( @@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const { "shape ", tensor_.shape().DebugString()); } - *dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape()); + *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape()); if (!dst->scalar()().ParseFromString( string(static_cast(Data()), ByteSize()))) { return InvalidArgument( @@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const { const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* limit = input + src_size; - *dst = Tensor(tensor_.dtype(), tensor_.shape()); + *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape()); auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h index 08a55f26a83..2d31418fd18 100644 --- a/tensorflow/c/tf_tensor_internal.h +++ b/tensorflow/c/tf_tensor_internal.h @@ -31,7 +31,7 @@ limitations under the License. // passed to or returned from C functions *by pointer*. Otherwise, changes to // its internal structure will break the C API's binary interface. typedef struct TF_Tensor { - std::unique_ptr tensor; + std::unique_ptr tensor; } TF_Tensor; class TF_ManagedBuffer : public tensorflow::TensorBuffer { diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 5ebf5c9d47f..2b2e44bc619 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -124,7 +124,7 @@ Status LoadSavedModel(const SessionOptions& session_options, /// the export directory definitely does not contain a SavedModel. If the method /// returns `true`, the export directory may contain a SavedModel but provides /// no guarantee that it can be loaded. -bool MaybeSavedModelDirectory(const string& export_dir); +bool MaybeSavedModelDirectory(const std::string& export_dir); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index b776ee77493..4d07b8d26c1 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -115,11 +115,10 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index 422fca3308e..d69560220f2 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -12,7 +12,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") _default_test_file_exts = ["mlir", ".pbtxt", ".td"] _default_driver = "@llvm-project//mlir:run_lit.sh" _default_size = "small" -_default_tags = ["no_rocm"] +_default_tags = [] # These are patterns which we should never match, for tests, subdirectories, or # test input data files. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 61919204f9a..32a977416ae 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -17,8 +17,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ - "//learning/brain/experimental/mlir/...", - "//learning/brain/google/xla/...", + "//learning/brain/mlir/...", "//tensorflow/compiler/mlir/...", ], ) @@ -205,8 +204,6 @@ cc_library( cc_library( name = "tensorflow_lite", srcs = [ - "experimental/estimators/estimator.h", - "experimental/estimators/gpu_estimator.h.inc", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", @@ -216,7 +213,6 @@ cc_library( "utils/attribute_utils.cc", ], hdrs = [ - "experimental/estimators/hardware.h", "ir/tfl_ops.h", "transforms/passes.h", "utils/attribute_utils.h", @@ -226,6 +222,7 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", + "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD new file mode 100644 index 00000000000..79ee35f83fc --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "cost_estimators", + textual_hdrs = [ + "estimator.h", + "gpu_estimators.h", + "hardware.h", + ], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc deleted file mode 100644 index 819056dcc91..00000000000 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ - -// tfl.average_pool_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - // TODO(renjieliu): We probably need to check for dynamic weights. - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.depthwise_conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.max_pool_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h new file mode 100644 index 00000000000..8581187be70 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h @@ -0,0 +1,231 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ + +// tfl.add +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.average_pool_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.concatenation +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + // TODO(renjieliu): We probably need to check for dynamic weights. + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.conv_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + // TODO(renjieliu): We probably need to check for dynamic weights. + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.depthwise_conv_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.fully_connected +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + // TODO(renjieliu): we need to check for dynamic weights. + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.logistic +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.max_pool_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.mirror_pad +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.maximum +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.minimum +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.mul +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.relu +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.relu6 +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.reshape +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.softmax +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ + diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index baef9a41e3a..42ac0af48d0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -54,7 +54,7 @@ class TensorFlowLiteDialect : public Dialect { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" // Include all specializes estimators below this line -#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc" +#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h" } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index d12a1e28908..7226f68cc90 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -139,8 +139,7 @@ def TFL_Uint8 : UI<8>; def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>; def TFL_BoolTensor : TFL_TensorOf<[I1]>; -def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>; -def TFL_FpTensor : TFL_TensorOf<[AnyFloat]>; +def TFL_FpTensor : TFL_TensorOf<[F32]>; def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>; def TFL_I32Tensor : TFL_TensorOf<[I32]>; def TFL_I64Tensor : TFL_TensorOf<[I64]>; @@ -324,9 +323,9 @@ class TFL_ConvOp : }]; let arguments = ( - ins AnyTensor:$input, - AnyTensor:$filter, - TFL_TensorOfOrNone<[AnyType]>:$bias, + ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, QI8, QUI8]>:$filter, + TFL_TensorOfOrNone<[F32, I32]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -335,7 +334,7 @@ class TFL_ConvOp : I32Attr:$stride_w ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); let hasOptions = 0b1; } @@ -361,7 +360,10 @@ an output element, this operation computes \\(y = |x|\\). let hasFolder = 1; } -def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { +def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, + NoSideEffect, + Commutative, + TFL_GpuTargetOp]> { let summary = "Addition operator"; let description = [{ @@ -394,11 +396,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu }]; let arguments = (ins - TFL_VariadicTensorOf<[F32, I32, QI16, QUI16]>:$inputs + TFL_VariadicTensorOf<[F32, I32]>:$inputs ); let results = (outs - TFL_TensorOf<[F32, I32, QI16, QUI16]>:$sum + TFL_TensorOf<[F32, I32]>:$sum ); } @@ -492,7 +494,7 @@ def TFL_AveragePool2DOp: }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, I32Attr:$filter_height, I32Attr:$filter_width, TFL_PaddingAttr:$padding, @@ -501,7 +503,7 @@ def TFL_AveragePool2DOp: TFL_AFAttr:$fused_activation_function ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); let hasOptions = 1; let customOption = "Pool2DOptions"; @@ -577,7 +579,8 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", NoSideEffect, PredOpTrait<"values and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - SameOperandsAndResultsScale + SameOperandsAndResultsScale, + TFL_GpuTargetOp ]> { let summary = "Concatenation operator"; @@ -719,7 +722,18 @@ def TFL_CosOp: TFL_Op<"cos", [ def TFL_DepthwiseConv2DOp : TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { - let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier)); + let arguments = ( + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, + TFL_TensorOf<[F32, QI8, QUI8]>:$filter, + TFL_TensorOfOrNone<[F32, I32, I64]>:$bias, + I32Attr:$dilation_h_factor, + I32Attr:$dilation_w_factor, + TFL_AFAttr:$fused_activation_function, + TFL_PaddingAttr:$padding, + I32Attr:$stride_h, + I32Attr:$stride_w, + I32Attr:$depth_multiplier + ); let extraClassDeclaration = [{ // ChannelDimIndexInterface: @@ -741,7 +755,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ NoSideEffect, AccumulatorUniformScale<2, 0, 1>, TFL_ChannelDimIndexInterface, AffineOpCoefficient<-1, 1>, - TFL_SparseOp]> { + TFL_SparseOp, + TFL_GpuTargetOp]> { let summary = "Fully connected op"; let arguments = (ins @@ -1091,6 +1106,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ } def TFL_DivOp : TFL_Op<"div", [ + // TODO(fengliuai): NoQuantizableResult is only correct for int8 + // quantization. update to handle Uint8 quantization. ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { let summary = "Division operator"; @@ -1099,11 +1116,11 @@ def TFL_DivOp : TFL_Op<"div", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs, + ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs, + TFL_TensorOf<[F32, I32, TFL_Uint8]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32, TFL_Uint8]>:$output); let builders = [TFL_FusedBroadcastableBinaryBuilder]; @@ -1126,7 +1143,7 @@ def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); let hasOptions = 0; } @@ -1487,16 +1504,17 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ // zero_point = 0 // scale = 1. / (max_value + 1) FixedResultScale>, - FixedResultScale>]> { + FixedResultScale>, + TFL_GpuTargetOp]> { let summary = "Logistic operator"; let description = [{ Computes element-wise Sigmoid of input }]; - let arguments = (ins TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x); - let results = (outs TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y); } def TFL_LogOp: TFL_Op<"log", [ @@ -1639,19 +1657,23 @@ def TFL_MaxUnpooling2DOp : } def TFL_MaximumOp : TFL_Op<"maximum", [ - ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> { + ResultsBroadcastableShape, + NoSideEffect, + Commutative, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Max operator"; let description = [{ Element-wise max operation. }]; let arguments = ( - ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs ); let results = (outs - TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1837,19 +1859,23 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { } def TFL_MinimumOp : TFL_Op<"minimum", [ - ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> { + ResultsBroadcastableShape, + NoSideEffect, + Commutative, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Min operator"; let description = [{ Element-wise min operation. }]; let arguments = ( - ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs ); let results = (outs - TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1857,7 +1883,10 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ let hasOptions = 0; } -def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { +def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, + NoSideEffect, + Commutative, + TFL_GpuTargetOp]> { let summary = "Multiplication operator"; let description = [{ @@ -2090,7 +2119,8 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> { def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu operator"; let description = [{ @@ -2105,7 +2135,8 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu6 operator"; let description = [{ @@ -2134,7 +2165,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, } def TFL_ReshapeOp: TFL_Op<"reshape", [ - NoSideEffect, SameOperandsAndResultsScale]> { + NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> { let summary = "Reshape operator"; let description = [{ @@ -2351,7 +2382,8 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ // zero_point = 0 // scale = 1. / (max_value + 1) FixedResultScale>, - FixedResultScale>]> { + FixedResultScale>, + TFL_GpuTargetOp]> { let summary = "Softmax operator"; let description = [{ @@ -2882,7 +2914,7 @@ def TFL_CastOp : TFL_Op<"cast", [ TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$input ); - let results = (outs TFL_TensorOf<[F32, I1, I32, I64, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. @@ -2891,7 +2923,7 @@ def TFL_CastOp : TFL_Op<"cast", [ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ - NoSideEffect, TFL_OperandHasRank<1, 2>]> { + NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> { let summary = "MirrorPad Operator. Pads a tensor with mirrored values."; let description = [{ @@ -3400,7 +3432,7 @@ def TFL_BidirectionalSequenceLSTMOp : let summary = "Bidirectional sequence lstm operator"; let description = [{ - Bidirectional lstm is essentiallay two lstms, one running forward & the + Bidirectional lstm is essentially two lstms, one running forward & the other running backward. And the output is the concatenation of the two lstms. }]; diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 7d5e6e43e82..a75135cf3b5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -111,3 +111,31 @@ tf_native_cc_binary( "@llvm-project//mlir:TableGen", ], ) + +cc_library( + name = "device_target", + srcs = ["device_target.cc"], + hdrs = ["device_target.h"], + deps = [ + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "quantization_context", + srcs = ["quantization_context.cc"], + hdrs = ["quantization_context.h"], + deps = [ + ":device_target", + ":quantization_lib", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc new file mode 100644 index 00000000000..b1d72017657 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace quant { + +constexpr int k8Bits = 8; +constexpr unsigned kSigned = quant::QuantizationFlags::Signed; + +DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { + f32_ = FloatType::getF32(ctx_); + i8_ = IntegerType::get(k8Bits, ctx_); + i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits); + i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits); + any_ = AnyQuantizedType(); + qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_); + qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_); + assert(qi8n_ == qi8n_); +} + +Optional DeviceTarget::Get(QuantizeRegionOp op) const { + auto kernel_specs_it = specs_.find(op.logical_kernel()); + if (kernel_specs_it == specs_.end()) return llvm::None; + + KernelSpecs::Signature signature; + signature.reserve(op.input_specs().size() + op.output_specs().size()); + AppendToSignature(op.input_specs(), &signature); + AppendToSignature(op.output_specs(), &signature); + return kernel_specs_it->getValue().Find(signature); +} + +LogicalResult DeviceTarget::RegisterKernel( + llvm::StringRef kernel, const KernelSpecs::Signature& signature, + const ScaleFn& fn) { + return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn}); +} + +LogicalResult DeviceTarget::RegisterKernel( + llvm::StringRef kernel, const KernelSpecs::Signature& signature, + const ScaleConstraintType constraint) { + return specs_[kernel].Add(signature, {constraint, {}}); +} + +void DeviceTarget::AppendToSignature(ArrayAttr specs_attr, + KernelSpecs::Signature* signature) const { + for (auto attr : specs_attr) { + Type spec = attr.cast().getValue(); + if (auto quant = spec.dyn_cast()) { + signature->push_back(AnyQuantizedType::get( + quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), + quant.getStorageTypeMin(), quant.getStorageTypeMax())); + } else if (auto any = spec.dyn_cast()) { + signature->push_back(any); + } else { // float + signature->push_back({}); + } + } +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h new file mode 100644 index 00000000000..ee5f1fe7a4c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace quant { + +class QuantizeContext; + +using AdjacentOperations = llvm::SmallVectorImpl; +using ScaleFn = std::function; + +enum class ScaleConstraintType { + OutputInputSameScale, + OutputInputFreeScale, + CustomScale, +}; + +// Each kernel signature has its own specification for scales. +struct KernelSpec { + // Scale constraint + ScaleConstraintType type; + + // Custom function to derive the scales. Only available when the scale + // constraint is `CustomScale`. + ScaleFn scale_fn; +}; + +class KernelSpecs { + public: + using Signature = llvm::SmallVector; + + // Returns the kernel specification for the kernel signature. + Optional Find(const Signature& signature) const { + auto spec_it = all_signatures_.find(signature); + if (spec_it != all_signatures_.end()) { + return spec_it->second; + } else { + return llvm::None; + } + } + + // Adds the kernel signature with the kernel specification. + LogicalResult Add(const Signature& signature, const KernelSpec& spec) { + if (all_signatures_.insert({signature, spec}).second) return success(); + return failure(); + } + + private: + // The signature is pattern match based. + struct SignatureInfo : public llvm::DenseMapInfo { + static inline Signature getEmptyKey() { return {}; } + static inline Signature getTombstoneKey() { return {nullptr}; } + static unsigned getHashValue(Signature val) { + return llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(Signature LHS, Signature RHS) { + if (RHS == getEmptyKey()) return LHS == getEmptyKey(); + if (RHS == getTombstoneKey()) return LHS == getTombstoneKey(); + if (LHS.size() != RHS.size()) return false; + for (auto arg : llvm::zip(LHS, RHS)) { + if (std::get<0>(arg) != std::get<1>(arg)) return false; + } + return true; + } + }; + + // Maps the signature to the kernel spec. Note that the matching is + // pattern match based. + llvm::DenseMap all_signatures_; +}; + +class DeviceTarget { + public: + explicit DeviceTarget(MLIRContext* ctx); + + // Retrieves the kernel spec for the quant region op. + Optional Get(quant::QuantizeRegionOp op) const; + + protected: + // Adds the kernel spec with the custom scale function for the kernel. + LogicalResult RegisterKernel(llvm::StringRef kernel, + const KernelSpecs::Signature& signature, + const ScaleFn& fn); + + // Adds the kernel spec with the scale constraint type for the kernel. + LogicalResult RegisterKernel(llvm::StringRef kernel, + const KernelSpecs::Signature& signature, + const ScaleConstraintType constraint); + + // converts specification to signature: + // - UniformedQuantizedType -> AnyQuantizedType + // - AnyQuantizedType (int) -> AnyQuantizedType + // - Float -> {} + void AppendToSignature(ArrayAttr specs_attr, + KernelSpecs::Signature* signature) const; + + // A set of parameters are required to build the signatures. + FloatType f32_; + IntegerType i8_; + int64_t i8_min_, i8_max_; + AnyQuantizedType any_, qi8_, qi8n_; + + private: + // Maps the kernel names to all the available kernels. + llvm::StringMap specs_; + + // Points to the global MLIRContext. + MLIRContext* ctx_; +}; + +} // namespace quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 6eb72dab2fc..9b49757fd3f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -34,11 +34,11 @@ limitations under the License. namespace mlir { namespace lite { -// TODO(fengliuai): check the result for `allow_float` flag. +// TODO(fengliuai): check the result for `fully_quantize` flag. TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool allow_float, + const std::unordered_set& operator_names, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter) { // TODO(b/142502494): remove this restriction by improving the `emit_adaptor` diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 0e040570ee6..473e97e07df 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -27,11 +27,11 @@ namespace lite { // Quantize the `input_model` and write the result to a flatbuffer `builder`. // The `input_type` and `output_type` can be float32/qint8/int8. -// Return partially quantized model if `allow_float` is true. +// Return partially quantized model if `fully_quantize` is false. TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool allow_float, + const std::unordered_set& operator_names, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter); } // namespace lite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 0138e0e8276..7530cdf008f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -47,7 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, - /*allow_float=*/false, builder, &error_reporter); + /*fully_quantize=*/true, builder, &error_reporter); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc new file mode 100644 index 00000000000..85a988a9bde --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -0,0 +1,239 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" + +#define DEBUG_TYPE "quantization-context" + +namespace mlir { +namespace quant { + +QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec) + : func_(func), target_spec_(spec) { + llvm::DenseMap value_to_state; + func.walk([&](quant::QuantizeRegionOp op) { + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + states_manager_.InitializeOperandState(op, i, &value_to_state); + } + + for (int res = 0, e = op.getNumResults(); res != e; ++res) { + states_manager_.InitializeResultState(op, res, &value_to_state); + } + }); +} + +llvm::ArrayRef QuantizeContext::GetAllOps() { + llvm::SmallVector all_ops; + func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); }); + return all_ops; +} + +LogicalResult QuantizeContext::Handle( + quant::QuantizeRegionOp op, llvm::SmallVectorImpl *new_items, + bool *changed) { + auto spec = target_spec_.Get(op); + if (!spec.hasValue()) { + op.emitWarning( + "Couldn't find kernel from the registeration for quantization."); + return success(); + } + switch (spec->type) { + case ScaleConstraintType::OutputInputFreeScale: { + // no propagation. + *changed = false; + break; + } + case ScaleConstraintType::CustomScale: { + if (failed(spec->scale_fn(this, op, new_items, changed))) { + return failure(); + } + break; + } + default: { + llvm_unreachable("no implementation."); + return failure(); + } + } + return success(); +} + +LogicalResult QuantizeContext::Finalize() { + MLIRContext *context = func_.getContext(); + func_.walk([&](quant::QuantizeRegionOp op) { + llvm::SmallVector input_specs; + auto original_input_specs = op.input_specs().getValue(); + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + auto &state = states_manager_.GetOperandQuantState(op, i); + auto &requantize = states_manager_.GetOperandRequantizeState(op, i); + if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) { + input_specs.push_back(original_input_specs[i]); + } else if (requantize.pos == RequantizeState::ON_OUTPUT) { + input_specs.push_back(TypeAttr::get(requantize.params)); + } else { + input_specs.push_back(TypeAttr::get(state.params)); + } + } + op.setAttr("input_specs", ArrayAttr::get(input_specs, context)); + + llvm::SmallVector output_specs; + auto original_output_specs = op.output_specs().getValue(); + for (int res = 0, e = op.getNumResults(); res != e; ++res) { + auto &state = states_manager_.GetResultQuantState(op, res); + auto &requantize = states_manager_.GetResultRequantizeState(op, res); + if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) { + output_specs.push_back(original_output_specs[res]); + } else if (requantize.pos == RequantizeState::ON_INPUT) { + output_specs.push_back(TypeAttr::get(requantize.params)); + } else { + output_specs.push_back(TypeAttr::get(state.params)); + } + } + op.setAttr("output_specs", ArrayAttr::get(output_specs, context)); + }); + return success(); +} + +void QuantizeContext::DumpStates(QuantizeRegionOp current_op) { + if (current_op) { + llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n"; + } + func_.walk([&](QuantizeRegionOp op) { + if (current_op == op) llvm::errs() << "===>>>"; + llvm::errs() << op.logical_kernel() << " : ("; + for (auto i = 0; i < op.getNumOperands(); ++i) { + if (auto params = GetOperandParams(op, i)) + params.print(llvm::errs()); + else + llvm::errs() << "_"; + llvm::errs() << ","; + } + llvm::errs() << ") -> ("; + for (auto i = 0; i < op.getNumResults(); ++i) { + if (auto params = GetResultParams(op, i)) + params.print(llvm::errs()); + else + llvm::errs() << "_"; + llvm::errs() << ","; + } + llvm::errs() << ")\n"; + }); +} + +int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op, + int index, bool as_result) { + Attribute params_attr; + if (as_result) { + params_attr = op.output_specs()[index]; + } else { + params_attr = op.input_specs()[index]; + } + QuantParams params = + params_attr.cast().getValue().dyn_cast(); + bool immutable = !EmptyParams(params); + int next_state_index = states_.size(); + states_.push_back({params, immutable}); + if (as_result) { + result_states_.insert({{op, index}, next_state_index}); + } else { + operand_states_.insert({{op, index}, next_state_index}); + } + return next_state_index; +} + +void QuantizeContext::StatesManager::InitializeOperandState( + quant::QuantizeRegionOp op, int index, llvm::DenseMap *cache) { + Value in = op.getOperand(index); + auto cached = cache->insert({in, 0}); + if (!cached.second) { + operand_states_.insert({{op, index}, cached.first->second}); + return; + } + cached.first->second = InitializeState(op, index, /*as_result=*/false); +} + +void QuantizeContext::StatesManager::InitializeResultState( + quant::QuantizeRegionOp op, int index, llvm::DenseMap *cache) { + auto res = op.getResult(index); + auto cached = cache->insert({res, 0}); + if (!cached.second) { + result_states_.insert({{op, index}, cached.first->second}); + return; + } + cached.first->second = InitializeState(op, index, /*as_result=*/true); +} + +bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) { + llvm_unreachable("no implementation."); + return false; +} + +bool QuantizeContext::StatesManager::SetResultParams(Operation *op, + int res_index, + QuantParams params) { + auto &state = GetResultQuantState(op, res_index); + if (state.params == params) { + return false; + } + if (!state.IsEmpty()) { + auto &rescale = GetResultRequantizeState(op, res_index); + rescale.params = params; + rescale.pos = RequantizeState::ON_INPUT; + return false; + } + state.params = params; + return true; +} + +bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index, + QuantParams params) { + auto &state = GetOperandQuantState(op, index); + if (state.params == params) { + return false; + } + + if (!state.IsEmpty()) { + auto &rescale = GetOperandRequantizeState(op, index); + rescale.params = params; + rescale.pos = RequantizeState::ON_OUTPUT; + return false; + } + state.params = params; + return true; +} +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h new file mode 100644 index 00000000000..35ed1feaaab --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -0,0 +1,217 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ + +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" + +namespace mlir { +namespace quant { + +static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); } + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. + QuantParams params; + // A flag indicates this state (the params) shouldn't be changed after it is + // initialized. This flag will be set to true if the quantization parameters + // are from the quantization-aware training. + const bool immutable; + + bool IsEmpty() { return EmptyParams(params); } +}; + +// The state for rescaling the propagated quantization parameters. This can be +// on the input side to satisfy the constraint of previous operation, or on the +// output side to satisfy the constraint of the next operation. +struct RequantizeState { + // Sometimes, we have to "requantize" the quantization result to satisfy all + // the constraints. The "requantize" can happen either on the input or output + // of the quantization result. + enum RequantizePosition { + NO_REQUANTIZE, + ON_INPUT, + ON_OUTPUT + } pos = NO_REQUANTIZE; + + // Quantization parameters will be used to add the requantize ops. + QuantParams params; +}; + +// This class manages all the intermedaite quantization states. +class QuantizeContext { + public: + QuantizeContext(FuncOp func, const DeviceTarget &spec); + + // Returns all the quant region ops. + ArrayRef GetAllOps(); + + // For each quant region op, propagates its quantization parameters according + // to the kernel specification and also returns the adjcent quant region ops + // which get the new quantization parameters propagated. + LogicalResult Handle(quant::QuantizeRegionOp op, + llvm::SmallVectorImpl *new_items, + bool *changed); + + // Updates the port quantization specifications of all the quant region ops + // with the propagation results. + LogicalResult Finalize(); + + // Dumps the states stores in the state manager. + void DumpStates(QuantizeRegionOp current_op = {}); + + // Update the quantization parameter for certain result of the op. By this + // method, the quantization parameter is propagated to all the users of the + // result as well. + bool SetResultParams(Operation *op, int index, QuantParams params) { + return states_manager_.SetResultParams(op, index, params); + } + + // Update the quantization parameter for certain operand of the op. By this + // method, the quantization parameter is propagated to the defining op of + // operand as well. + bool SetOperandParams(Operation *op, int index, QuantParams params) { + return states_manager_.SetOperandParams(op, index, params); + } + + // Return the quantization parameter of certain result of the op. + QuantParams GetResultParams(Operation *op, int index) { + return states_manager_.GetResultParams(op, index); + } + + // Return the quantization parameter of certain operand of the op. + QuantParams GetOperandParams(Operation *op, int index) { + return states_manager_.GetOperandParams(op, index); + } + + private: + class StatesManager { + public: + // Sets the quantization parameters of the constant result according to its + // content. + // + // Always returns true. + bool SetConstantResultParams(Operation *op); + + // Sets the quantization parameters of the result to a fixed value. If any + // quantization parameters have been propagated, a `requantize` will happen + // on the input of propagated quantization. + // + // Returns true, if the users of the result needs to be added to the + // worklist. + bool SetResultParams(Operation *op, int index, QuantParams params); + + // Sets the quantization parameters of the operand to a fixed value. If any + // quantization parameters have been propagated, a `requantize` will happen + // on the output of propagated quantization. + // + // Returns true, if the defining op of the operand needs to be added to the + // worklist. + bool SetOperandParams(Operation *op, int index, QuantParams params); + + // Returns the quantization parameters of the index-th result of the op. + QuantParams GetResultParams(Operation *op, int index) { + return states_[result_states_[{op, index}]].params; + } + + // Returns the quantization parameters of the index-th operand of the op. + QuantParams GetOperandParams(Operation *op, int index) { + return states_[operand_states_[{op, index}]].params; + } + + private: + friend class QuantizeContext; + + // Uses the type of `val` to set the initial state of the index-th result if + // `as_result` is true or index-th operand if `as_result` is false. The + // state is immutable if the type is a quantized type. Returns the index of + // this new state in the state vector. + int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeOperandState(quant::QuantizeRegionOp op, int index, + llvm::DenseMap *cache); + + // Sets the state of the index-th result of the op. If this result is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeResultState(quant::QuantizeRegionOp op, int index, + llvm::DenseMap *cache); + + // Returns the state of the index-th operand of the op. + QuantState &GetOperandQuantState(Operation *op, int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the state of the index-th result of the op. + QuantState &GetResultQuantState(Operation *op, int index) { + return states_[result_states_[{op, index}]]; + } + + // Returns the state of the index-th operand of the op. + RequantizeState &GetOperandRequantizeState(Operation *op, int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the state of the index-th result of the op. + RequantizeState &GetResultRequantizeState(Operation *op, int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + private: + // This is used to identify an operand or result of an op. The second + // element of this pair is the index of the operand or result. + using OpValue = std::pair; + + // The vector contains all the quantization parameters propagated from the + // defining operations of the value, or from the quantization aware + // training. + std::vector states_; + + // The map contains all the quantization parameters which are required to + // satisfy the same operands and results constraint. The keys of this map + // are the values from `operand_states_` and `result_state_`. + std::unordered_map rescale_states_; + + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. + llvm::DenseMap operand_states_; + llvm::DenseMap result_states_; + }; + + FuncOp func_; + + DeviceTarget target_spec_; + + StatesManager states_manager_; +}; + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD index 2ce36709e9c..2bc1568eb17 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD @@ -33,7 +33,9 @@ cc_library( "passes.h", ], deps = [ + ":cpu_device_target", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization:quantization_context", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/xla/client/lib:quantize", @@ -49,6 +51,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "cpu_device_target", + srcs = [ + "cpu_device_target.cc", + ], + hdrs = [ + "cpu_device_target.h", + ], + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:device_target", + "//tensorflow/compiler/mlir/lite/quantization:quantization_context", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "quantize", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc new file mode 100644 index 00000000000..e4bdafa89ff --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h" + +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h" + +namespace mlir { +namespace xla_hlo { + +namespace ph = std::placeholders; + +CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) { + RegisterKernel("generic.concat", {qi8_, qi8_, qi8_}, + quant::ScaleConstraintType::OutputInputSameScale); + RegisterKernel("generic.mul", {qi8_, qi8_, qi8_}, + quant::ScaleConstraintType::OutputInputFreeScale); + RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_}, + std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale, + this, ph::_1, ph::_2, ph::_3, ph::_4)); + RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_}, + std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale, + this, ph::_1, ph::_2, ph::_3, ph::_4)); +} + +LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale( + quant::QuantizeContext* ctx, Operation* op, + quant::AdjacentOperations* new_items, bool* changed) { + auto bias_params = ctx->GetOperandParams(op, 2); + if (!EmptyParams(bias_params)) { + return success(); + } + std::vector op_types{ctx->GetOperandParams(op, 0), + ctx->GetOperandParams(op, 1)}; + auto bias_scale = GetUniformQuantizedTypeForBias(op_types); + if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) { + *changed = true; + new_items->push_back(op->getOperand(2).getDefiningOp()); + } + return success(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h new file mode 100644 index 00000000000..a2b05fb6a00 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ + +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" + +namespace mlir { +namespace xla_hlo { + +// Target specs for cpu kernels +class CpuDeviceTarget : public quant::DeviceTarget { + public: + explicit CpuDeviceTarget(MLIRContext* ctx); + + private: + LogicalResult HandleMultiplyAccumulateScale( + quant::QuantizeContext* ctx, Operation* op, + quant::AdjacentOperations* new_items, bool* changed); +}; + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc index 4087eeb3c09..c4c5904209c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc @@ -26,7 +26,9 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h" // NOLINTNEXTLINE static llvm::cl::opt disable_per_channel( @@ -59,9 +61,36 @@ struct PropagateQuantPass : public FunctionPass { void PropagateQuantPass::runOnFunction() { FuncOp func = getFunction(); + // TODO(fengliuai): deprecate this old code generation path. // XLA only support uint8/uint16 quantization for now. ApplyQuantizationParamsPropagation(func, /*is_signed*/ false, disable_per_channel, GetOpQuantSpec); + + CpuDeviceTarget spec(&getContext()); + quant::QuantizeContext ctx(func, spec); + + std::vector work_list(ctx.GetAllOps()); + bool changed = false; + while (!work_list.empty()) { + quant::QuantizeRegionOp op = work_list.back(); + work_list.pop_back(); + + llvm::SmallVector new_items; + if (failed(ctx.Handle(op, &new_items, &changed))) { + // The IR is still valid, thus we shouldn't fail. + signalPassFailure(); + } + for (auto item : new_items) { + if (auto reg = llvm::dyn_cast_or_null(item)) + work_list.push_back(reg); + } + } + + if (!changed) return; + + if (failed(ctx.Finalize())) { + signalPassFailure(); + } } } // namespace diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir new file mode 100644 index 00000000000..05ac48c9f39 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir @@ -0,0 +1,54 @@ +// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure + +// ----- + +// CHECK-LABEL: @mul_add_source_no_params +func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %region = "quant.region"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors + %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + "quant.return"(%add) : (tensor<4xf32>) -> () + }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : + (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %region : tensor<4xf32> + +// CHECK: input_specs = [f32, f32, f32] +// CHECK-SAME: output_specs = [f32] +} + +// ----- + +// CHECK-LABEL: @mul_add_annotated_no_narrow_range +func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %region = "quant.region"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors + %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + "quant.return"(%add) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, f32], + logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : + (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %region : tensor<4xf32> + +// CHECK: input_specs = [!quant.uniform, !quant.uniform, f32] +// CHECK-SAME: output_specs = [!quant.uniform] +} + +// ----- + +// CHECK-LABEL: @mul_add_annotated +func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %region = "quant.region"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors + %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> + "quant.return"(%add) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform:f32, 1.0:-128>, f32], + logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : + (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %region : tensor<4xf32> + +// CHECK: input_specs = [!quant.uniform, !quant.uniform:f32, 1.000000e+00:-128>, !quant.uniform] +// CHECK-SAME: output_specs = [!quant.uniform] +} diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index a4ebc997991..0d612cec961 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -5,6 +5,11 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "legalize-tf.mlir": ["no_rocm"], + "optimize.mlir": ["no_rocm"], + "prepare-tf.mlir": ["no_rocm"], + }, test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt index b16ea3fa584..6b01e99da2a 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt @@ -1,6 +1,6 @@ # RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s -# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format +# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format node { name: "input" diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt index f2d8f6762cd..9a676c196ce 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt @@ -1,9 +1,9 @@ # RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s -# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format -# CHECK: fake/user/code/file_D.py: note: called from -# CHECK: fake/user/code/file_E.py: note: called from -# CHECK: fake/user/code/file_F.py: note: called from +# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format +# CHECK: fake/user/code/file_D.py:28:0: note: called from +# CHECK: fake/user/code/file_E.py:29:0: note: called from +# CHECK: fake/user/code/file_F.py:30:0: note: called from node { name: "input" diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index 732fd784bbc..9d768fec0ab 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -8,6 +8,12 @@ glob_lit_tests( ":test_utilities", ], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "add.pbtxt": ["no_rocm"], + "conv_2d.pbtxt": ["no_rocm"], + "fake_quant_per_channel.pbtxt": ["no_rocm"], + "ophint_lstm.pbtxt": ["no_rocm"], + }, test_file_exts = [ "pbtxt", ], diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir index 8f30aef8287..50ea5c1da41 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir @@ -50,8 +50,8 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t // INLINE: ^bb0([[ARGS]]): // INLINE: %cst_2 = constant // INLINE: yield -// INLINE: while_body -// INLINE: while_cond +// INLINE-NOT: while_body +// INLINE-NOT: while_cond // CANON-LABEL: func @while_main // CANON-SAME: ([[VAL_0:%.*]]: tensor) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index d236c8169b8..7db46f778fa 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -192,11 +192,11 @@ func @argmin(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { // CHECK: "tfl.arg_min"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor } -func @sigmoid(%arg0: tensor) -> tensor { - %0 = "tf.Sigmoid"(%arg0) : (tensor) -> tensor - return %0 : tensor +func @sigmoid(%arg0: tensor) -> tensor { + %0 = "tf.Sigmoid"(%arg0) : (tensor) -> tensor + return %0 : tensor // CHECK-LABEL: sigmoid -// CHECK: "tfl.logistic"(%arg0) : (tensor) -> tensor +// CHECK: "tfl.logistic"(%arg0) : (tensor) -> tensor } func @sqrt(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { @@ -1316,16 +1316,6 @@ func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1> // CHECK: return } -func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> { - %0 = "tf.Reciprocal"(%arg0) : (tensor<8xf16>) -> tensor<8xf16> - return %0: tensor<8xf16> - -// CHECK-LABEL: reciprocal_f16 -// CHECK: %cst = constant dense<1.000000e+00> : tensor -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xf16>) -> tensor<8xf16> -// CHECK: return -} - func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> { %0 = "tf.Reciprocal"(%arg0) : (tensor<8xf32>) -> tensor<8xf32> return %0: tensor<8xf32> @@ -1336,16 +1326,6 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> { // CHECK: return } -func @reciprocal_complex_f32(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { - %0 = "tf.Reciprocal"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> - return %0: tensor<8xcomplex> - -// CHECK-LABEL: reciprocal_complex_f32 -// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor>, tensor<8xcomplex>) -> tensor<8xcomplex> -// CHECK: return -} - func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> { %0 = "tf.Reciprocal"(%arg0) : (tensor<8xi32>) -> tensor<8xi32> return %0: tensor<8xi32> @@ -1356,16 +1336,6 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> { // CHECK: return } -func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> { - %0 = "tf.Reciprocal"(%arg0) : (tensor<8xi64>) -> tensor<8xi64> - return %0: tensor<8xi64> - -// CHECK-LABEL: reciprocal_i64 -// CHECK: %cst = constant dense<1> : tensor -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xi64>) -> tensor<8xi64> -// CHECK: return -} - func @random_uniform() -> tensor<2x5xf32> { %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32> %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir index 3addd8a9248..53785e728d1 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir @@ -14,7 +14,6 @@ // CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes // CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes // CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes -// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes // Verify while was not folded away: // ------------------------------------ diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 995f20c4a07..007fb9d4dae 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -16,7 +16,7 @@ func @testCos(tensor) -> tensor { // test invalid Cos input func @testCosWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{tfl.cos' op operand #0 must be tensor of 32-bit float values}} %0 = "tfl.cos"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -103,7 +103,7 @@ func @testAddN(tensor, tensor, tensor) -> tensor, tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): - // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer or QI16 type or QUI16 type values}} + // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer}} %0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -147,7 +147,7 @@ func @testSin(tensor) -> tensor { // test invalid Sin input func @testSinWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{tfl.sin' op operand #0 must be tensor of 32-bit float values}} %0 = "tfl.sin"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -157,7 +157,7 @@ func @testSinWithWrongInputType(tensor) -> tensor { // test invalid Sqrt input func @testSqrtWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of 32-bit float values}} %0 = "tfl.sqrt"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -167,7 +167,7 @@ func @testSqrtWithWrongInputType(tensor) -> tensor { // test invalid Square input func @testSquareWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{tfl.square' op operand #0 must be tensor of 32-bit float values}} %0 = "tfl.square"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -425,7 +425,7 @@ func @testTileF32(%arg0: tensor<4 x 1 x f32>, %arg1: tensor<4 x i32>) -> tensor< // ----- func @testEluI32(%arg0: tensor) -> tensor { - // expected-error @+1 {{operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{op operand #0 must be tensor of 32-bit float values}} %0 = "tfl.elu"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -531,11 +531,11 @@ func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf // ----- // CHECK-LABEL: testLogistic -func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> { -^bb0(%arg0: tensor<1x2x3x4x5xbf16>): +func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> { +^bb0(%arg0: tensor<1x2x3x4x5xf32>): // CHECK: "tfl.logistic"(%arg0) - %0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> - return %0 : tensor<1x2x3x4x5xbf16> + %0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> + return %0 : tensor<1x2x3x4x5xf32> } // ----- @@ -543,7 +543,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}} + // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index ae5bd6ced5e..d1ead351005 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -400,8 +400,8 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor< // FOLD: return %[[fc]] } -// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable -func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> { +// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter +func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> { %cst = constant dense<2.0> : tensor<40xf32> %shape = constant dense<[40, 40]> : tensor<2xi32> %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32> @@ -413,6 +413,19 @@ func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tens // CHECK: return %[[rs2]] } +// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDimAfter +func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) -> tensor<1x30x96xf32> { + %cst = constant dense<2.0> : tensor<1x30x96xf32> + %shape = constant dense<[1, 30, 96]> : tensor<3xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x30x1x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32> + %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<1x30x96xf32>) -> tensor<1x30x96xf32> + return %2 : tensor<1x30x96xf32> + + // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 + // CHECK: %[[rs2:.*]] = tfl.add %[[rs1]] + // CHECK: return %[[rs2]] +} + // CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> { %cst = constant dense<2.0> : tensor<1x40xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir index 3608d89e5e3..d3f4f5ba307 100644 --- a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir +++ b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir @@ -15,14 +15,14 @@ func @while() -> tensor<1xf32> %0:2 = "tfl.while"(%cst0, %cst1) ( { ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // CHECK: call @WhileOp_cond - // CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor) + // CHECK-SAME: (tensor<*xi32>, tensor<*xf32>) %cst_0 = constant dense<0> : tensor %1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor) -> tensor "tfl.yield"(%1) : (tensor) -> () }, { ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // CHECK: call @WhileOp_body - // CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor) + // CHECK-SAME: (tensor<*xi32>, tensor<*xf32>) %1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor) -> tensor<*xi32> %2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32> @@ -40,8 +40,7 @@ func @while() -> tensor<1xf32> // CHECK-LABEL: func @while2 // Verify that while body//cond with implicitly captured values result in changing while operands/results. -func @while2() -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} { - %cst = constant dense<1> : tensor +func @while2(%cst : tensor) -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} { %cst_0 = constant dense<5> : tensor %cst_1 = constant dense<3.000000e+00> : tensor<1xf32> // Verifies 3 operands post outlining. @@ -148,22 +147,21 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x // CHECK: tfl.while // CHECK: tfl.yield // CHECK-SAME: (tensor) -> () -// CHECK: [[VAL_41:%.*]]:18 = +// CHECK: [[VAL_30:%.*]]:7 = // CHECK: call @tfl.while_body // CHECK: tfl.yield -// CHECK-SAME: (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor, tensor, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor, tensor<1xi32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor<1xi32>) -> () +// CHECK-SAME: (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> () // CHECK-LABEL: func @tfl.while_cond( -// CHECK-SAME: [[VAL_56:%.*]]: tensor, [[VAL_57:%.*]]: tensor, [[VAL_58:%.*]]: tensor<*xf32>, [[VAL_59:%.*]]: tensor<4x2xf32>, [[VAL_60:%.*]]: tensor<4x2xf32>, [[VAL_61:%.*]]: tensor<*xf32>, [[VAL_62:%.*]]: tensor, [[VAL_63:%.*]]: tensor, [[VAL_64:%.*]]: tensor<4x4x3xf32>, [[VAL_65:%.*]]: tensor<8x5xf32>, [[VAL_66:%.*]]: tensor<8xf32>, [[VAL_67:%.*]]: tensor, [[VAL_68:%.*]]: tensor<1xi32>, [[VAL_69:%.*]]: tensor, [[VAL_70:%.*]]: tensor<1xi32>, [[VAL_71:%.*]]: tensor<1xi32>, [[VAL_72:%.*]]: tensor, [[VAL_73:%.*]]: tensor<1xi32>) -> tensor attributes {sym_visibility = "private"} { +// CHECK-SAME: [[VAL_35:%.*]]: tensor, [[VAL_36:%.*]]: tensor, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor attributes {sym_visibility = "private"} { // CHECK: return // CHECK-SAME: tensor // CHECK: } // CHECK-LABEL: func @tfl.while_body( -// CHECK-SAME: [[VAL_77:%.*]]: tensor, [[VAL_78:%.*]]: tensor, [[VAL_79:%.*]]: tensor<*xf32>, [[VAL_80:%.*]]: tensor<4x2xf32>, [[VAL_81:%.*]]: tensor<4x2xf32>, [[VAL_82:%.*]]: tensor<*xf32>, [[VAL_83:%.*]]: tensor, [[VAL_84:%.*]]: tensor, [[VAL_85:%.*]]: tensor<4x4x3xf32>, [[VAL_86:%.*]]: tensor<8x5xf32>, [[VAL_87:%.*]]: tensor<8xf32>, [[VAL_88:%.*]]: tensor, [[VAL_89:%.*]]: tensor<1xi32>, [[VAL_90:%.*]]: tensor, [[VAL_91:%.*]]: tensor<1xi32>, [[VAL_92:%.*]]: tensor<1xi32>, [[VAL_93:%.*]]: tensor, [[VAL_94:%.*]]: tensor<1xi32>) -> (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor, tensor, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor, tensor<1xi32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor<1xi32>) attributes {sym_visibility = "private"} { -// CHECK: [[VAL_123:%.*]] = "tfl.cast" +// CHECK-SAME: [[VAL_46:%.*]]: tensor, [[VAL_47:%.*]]: tensor, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} { +// CHECK: [[VAL_91:%.*]] = "tfl.cast" // CHECK: return -// CHECK-SAME: [[VAL_123]], [[VAL_83]], [[VAL_84]], [[VAL_85]], [[VAL_86]], [[VAL_87]], [[VAL_88]], [[VAL_89]], [[VAL_90]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]] : tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor, tensor, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor, tensor<1xi32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor<1xi32> +// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32> // CHECK: } // CHECK: } - diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index bb7a30e64f6..69a42d884cb 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -79,10 +79,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_config.quant_specs.serialized_quant_stats)); } - if (pass_config.lower_tensor_list_ops) { - // TODO(haoliang): Add this pass by default. - pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); - } + // The conversion pipeline has to follow the following orders: + // 1) Try to convert ophint nodes if present first like ophint lstm. + // 2) Saved model related optimization like decompose resource ops + // 3) Convert composite functions like lstm/rnns, along with proper function + // inlining & dce. + // 4) Lower static tensor list pass. // The ophint extractions happen before lots of other passes: // The assumption of ophint-extraction is each ophinted region is a black-box @@ -105,29 +107,40 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addNestedPass( mlir::TFDevice::CreateDecomposeResourceOpsPass()); - // This pass does resource analysis of saved model global tensors and marks - // those deemed read-only as immutable. - pass_manager->addPass( - mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // Note: + // We need to fuse composite ops before LowerStaticTensorList pass. + // The tensorflow list is not supported right now by that pass. + // Enable fusing composite ops that can be lowered to built-in TFLite ops. + if (pass_config.emit_builtin_tflite_ops) { + pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); + } + // This pass marks non-exported functions as symbol visibility 'private' // those deemed read-only as immutable. pass_manager->addPass( mlir::tf_saved_model:: CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); - // Enable fusing composite ops that can be lowered to built-in TFLite ops. - if (pass_config.emit_builtin_tflite_ops) { - pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); + pass_manager->addPass(mlir::createInlinerPass()); + pass_manager->addPass(mlir::createSymbolDCEPass()); + + if (pass_config.lower_tensor_list_ops) { + // TODO(haoliang): Add this pass by default. + pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); } + // This pass does resource analysis of saved model global tensors and marks + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // Legalize while early to allow further constant folding. // TODO(jpienaar): This may not actually matter as we do canonicalization // after the legalize below, for now it needs to be below the above passes // that work on TF dialect and before inliner so that the function calls in // body and cond are inlined for optimization. if (pass_config.legalize_tf_while) { - pass_manager->addNestedPass( - mlir::TFL::CreateLegalizeTFWhilePass()); + pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass()); } // Add function inlining pass. Both TF and TFLite dialects are opted into diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 66173c3c5b5..6d7713ad505 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -31,44 +31,52 @@ namespace { // Legalize TF While to TFL While with calls to the original functions from the // cond and body regions. -struct LegalizeWhile : public FunctionPass { - void runOnFunction() override { - auto func = getFunction(); - // Convert all TF WhileOps inside the function body to TFL While ops. - func.getBody().walk([](TF::WhileOp while_op) { - Operation* op = while_op.getOperation(); - // Create new TFL While op that will be used to replace TF While op. - auto new_op = OpBuilder(op).create( - op->getLoc(), op->getResultTypes(), op->getOperands(), - while_op.is_stateless()); - // Insert call to the given function into the 'region'. - auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol, - Region& region) { - OpBuilder builder(region); - auto block = builder.createBlock(®ion); - SmallVector new_operands; - auto func = while_op.getParentOfType().lookupSymbol( - symbol.getValue()); - for (Type t : func.getType().getInputs()) - new_operands.push_back(block->addArgument(t)); - auto call = - builder.create(while_op.getLoc(), symbol, - func.getType().getResults(), new_operands); - builder.create(while_op.getLoc(), call.getResults()); - }; - create_region_with_call(while_op.condAttr(), new_op.cond()); - create_region_with_call(while_op.bodyAttr(), new_op.body()); +struct LegalizeWhile : public ModulePass { + void RunOnFunction(FuncOp func); - op->replaceAllUsesWith(new_op.getResults()); - op->erase(); - }); + void runOnModule() override { + for (auto op : getModule().getOps()) RunOnFunction(op); } }; } // namespace +void RunOnWhile(TF::WhileOp while_op) { + Operation* op = while_op.getOperation(); + // Create new TFL While op that will be used to replace TF While op. + auto new_op = OpBuilder(op).create( + op->getLoc(), op->getResultTypes(), op->getOperands(), + while_op.is_stateless()); + // Insert call to the given function into the 'region'. + auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol, + Region& region) { + OpBuilder builder(region); + auto block = builder.createBlock(®ion); + SmallVector new_operands; + auto func = while_op.getParentOfType().lookupSymbol( + symbol.getValue()); + for (Type t : func.getType().getInputs()) + new_operands.push_back(block->addArgument(t)); + auto call = builder.create( + while_op.getLoc(), symbol, func.getType().getResults(), new_operands); + builder.create(while_op.getLoc(), call.getResults()); + // Mark old function as private so that it can be DCE'd if not called. + func.setVisibility(SymbolTable::Visibility::Private); + }; + create_region_with_call(while_op.condAttr(), new_op.cond()); + create_region_with_call(while_op.bodyAttr(), new_op.body()); + + op->replaceAllUsesWith(new_op.getResults()); + op->erase(); +} + +void LegalizeWhile::RunOnFunction(FuncOp func) { + // Convert all TF WhileOps inside the function body to TFL While ops. + func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); }); +} + // Creates an instance of the TensorFlow While to TFLite While pass. -std::unique_ptr> CreateLegalizeTFWhilePass() { +std::unique_ptr> CreateLegalizeTFWhilePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 8c6a2970397..6ad9a6d2267 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -862,6 +862,11 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + // Register fused LSTM/RNN ops as legal. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; populateWithGenerated(context, &patterns); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 144227b06af..916782d95b3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -339,11 +339,15 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { (ConstantOp:$rhs $a), TFL_AF_None), (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape), // The broadcasting of "BinaryOp" only happens in the lower - // dimensions, and the higher dimensions are same. + // dimensions, and the higher dimensions are same, so we know the + // result and input of the "BinaryOp" in the source pattern have + // the same shape, which is defined by `shape`. [(IsTailOfShape $rhs, $lhs), (HasOneUse $lhs), - // the two operands of the binary op is broadcastable - (AreBroadcastableTypes $rhs, $input)]>; + // The result of the new "BinaryOp" will have the same shape as + // `input`. In other words, the shape of the `Reshape` op are not + // changed after the transformation. + (IsTailOfShape $rhs, $input)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, @@ -363,11 +367,15 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, (ConstantOp:$rhs $a)), (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape), // The broadcasting of "BinaryOp" only happens in the lower - // dimensions, and the higher dimensions are same. + // dimensions, and the higher dimensions are same, so we know the + // result and input of the "BinaryOp" in the source pattern have + // the same shape, which is defined by `shape`. [(IsTailOfShape $rhs, $lhs), (HasOneUse $lhs), - // the two operands of the binary op is broadcastable - (AreBroadcastableTypes $rhs, $input)]>; + // The result of the new "BinaryOp" will have the same shape as + // `input`. In other words, the shape of the `Reshape` op are not + // changed after the transformation. + (IsTailOfShape $rhs, $input)]>; } // Returns shape of a ranked tensor. diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 22e1b8f636f..1c92c806585 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -86,7 +86,7 @@ std::unique_ptr> CreateDefaultQuantParamsPass( std::unique_ptr> CreateDenseToSparsePass(); // Creates function pass to legalize TF While to TFL While. -std::unique_ptr> CreateLegalizeTFWhilePass(); +std::unique_ptr> CreateLegalizeTFWhilePass(); // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. std::unique_ptr> CreateWhileOutlinePass(); diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index d4c359b6178..a0675efcc6b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -89,24 +90,20 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { llvm::SmallVector regions{&while_op.cond(), &while_op.body()}; for (auto it : llvm::enumerate(regions)) { llvm::SetVector region_extern_values; - Value const_none = nullptr; getUsedValuesDefinedAbove(*it.value(), region_extern_values); - // Sink down none type constants into the functions. + // Sink down constants into the functions. for (auto extern_value : region_extern_values) { - if (!extern_value.getType().isa()) { + if (!matchPattern(extern_value, m_Constant())) { extern_values.insert(extern_value); continue; } - if (!const_none) { - // Add constant at start of region. - auto const_builder = - OpBuilder(&it.value()->front(), it.value()->front().begin()); - const_none = const_builder.create( - while_op.getLoc(), extern_value.getType(), - const_builder.getUnitAttr()); - } - replaceAllUsesInRegionWith(extern_value, const_none, *it.value()); + // Add constant at start of region. + auto const_builder = + OpBuilder(&it.value()->front(), it.value()->front().begin()); + auto const_value = const_builder.clone(*extern_value.getDefiningOp()); + replaceAllUsesInRegionWith(extern_value, const_value->getResult(0), + *it.value()); } } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index e554686531a..af594b0125d 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -21,6 +21,10 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -86,6 +90,17 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { return *global; } +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + Status MlirFunctionOptimizationPass::Run( const DeviceSet& device_set, const ConfigProto& config_proto, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, @@ -107,6 +122,7 @@ Status MlirFunctionOptimizationPass::Run( << "(registered " << registry_->passes().size() << " passes)"; GraphDebugInfo debug_info; + RegisterDialects(); mlir::MLIRContext context; GraphImportConfig import_config; import_config.graph_as_function = true; @@ -178,9 +194,12 @@ Status MlirV1CompatGraphOptimizationPass::Run( << "(registered" << registry_->passes().size() << " passes)"; GraphDebugInfo debug_info; + RegisterDialects(); mlir::MLIRContext context; GraphImportConfig import_config; - import_config.upgrade_legacy = true; + // TODO(b/150959075): Running functionalization before TPU cluster formation + // is not semantics preserving and should be disabled for now. + import_config.upgrade_legacy = false; TF_ASSIGN_OR_RETURN( auto module_ref, ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index e14f9a211dc..c2120ccc4ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -10,6 +10,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ + "//learning/brain/experimental/tfrt/...", "//learning/pathways/data_parallel/tf2xla/...", "//tensorflow/compiler/...", "//tensorflow/lite/experimental/tf_runtime/...", @@ -986,7 +987,6 @@ cc_library( tf_cc_test( name = "error_util_test", srcs = ["utils/error_util_test.cc"], - tags = ["no_rocm"], deps = [ ":error_util", "//tensorflow/compiler/xla:test", @@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [ ":tensorflow_dialect_registration", ":tensorflow_passes", ":translate_utils", + "@com_google_absl//absl/types:optional", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", ] # Prefer to link 'compile_mlir_util' library that also links necessary diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d92964f6617..cdeb10cf03a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Clips tensor values to a specified min and max."; + + let description = [{ +Given a tensor `t`, this operation returns a tensor of the same type and +shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. +Any values less than `clip_value_min` are set to `clip_value_min`. Any values +greater than `clip_value_max` are set to `clip_value_max`. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> { let summary = "Converts two real numbers to a complex number."; @@ -2592,7 +2615,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>; } -def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect]> { +def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = "Gradient for batch normalization."; let description = [{ @@ -2623,9 +2646,17 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>; + + let extraClassDeclaration = [{ + // TF_LayoutSensitiveInterface: + SmallVector GetLayoutDependentArgs() { return {0, 1}; } + SmallVector GetLayoutDependentResults() { return {0}; } + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; } -def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { let summary = "Batch normalization."; let description = [{ @@ -2662,6 +2693,10 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. SmallVector GetLayoutDependentArgs() { return {0}; } SmallVector GetLayoutDependentResults() { return {0}; } LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); }]; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 3622a636c3b..842520927e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -1520,6 +1520,34 @@ static LogicalResult Verify(FillOp op) { return success(); } +//===----------------------------------------------------------------------===// +// FusedBatchNormGradOp +//===----------------------------------------------------------------------===// + +// TODO(b/150954845): Add benchmarks to verify that layout preference didn't +// change in the latest GPU generations. + +LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormGradV3Op::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + auto x_ty = x().getType().cast(); + const bool is_f16 = x_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For all other data types prefer NCHW. + return "NCHW"; +} + //===----------------------------------------------------------------------===// // FusedBatchNormOp //===----------------------------------------------------------------------===// @@ -1547,9 +1575,36 @@ static LogicalResult Verify(FusedBatchNormOp op) { LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( ArrayRef permutation) { + // FusedBatchNorm in training mode is a layout sentitive operation, and should + // have already assigned an optimal data format. + if (is_training()) return failure(); + return ::mlir::TF::FoldOperandsPermutation(permutation, this); } +LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { + // In inference mode FusedBatchNorm is not sensitive to data layout. + if (!is_training()) return data_format(); + + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + auto x_ty = x().getType().cast(); + const bool is_f16 = x_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For all other data types prefer NCHW. + return "NCHW"; +} + //===----------------------------------------------------------------------===// // GatherV2Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index a4ebc997991..daa583bed0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -5,6 +5,10 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "optimize.mlir": ["no_rocm"], + "tf_optimize.mlir": ["no_rocm"], + }, test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index 3dec94a98df..c09e2b25d99 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -7,6 +7,37 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> } +func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + // CHECK-LABEL: einsum_broadcast + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> +} + +func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32> + return %0 : tensor<5x7xf32> + // CHECK-LABEL: einsum_reducesum + // CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64> + // CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32> + // CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32> + // CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32> +} +func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32> + return %0 : tensor<5x3x7xf32> + // CHECK-LABEL: einsum_transpose_matmul + // CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32> + // CHECK: %[[cst_0:.*]] = constant dense<[0, 2, 1]> : tensor<3xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32> + // CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32> +} + func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> return %0 : tensor<2x7x5x4xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir index 73ae30c7831..0b9e995b386 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir @@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor { return %0 : tensor } +// CHECK-LABEL: test_single_branch_direct_arg_f +// CHECK: Switch +// CHECK: tf.AddV2 +func @test_single_branch_direct_arg_f(%pred : tensor) -> tensor { + %cst_0 = constant dense<10> : tensor + %cst_1 = constant dense<1> : tensor + %0 = tf_executor.graph { + %7:3 = tf_executor.Switch %cst_0, %pred : tensor + %8:2 = tf_executor.island { + %12 = "tf.AddV2"(%7#1, %cst_1) : (tensor, tensor) -> tensor + tf_executor.yield %12 : tensor + } + %11:3 = tf_executor.Merge %7#0, %8#0 : tensor {N = 2 : i64} + tf_executor.fetch %11#0 : tensor + } + return %0 : tensor +} + // pred ? x + 1 : x - 1 // CHECK-LABEL: ControlFlowTest.testCond_1f // CHECK-NOT: Switch @@ -330,4 +348,4 @@ func @switch_with_send_recv() { tf_executor.fetch } return -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt index ac248041994..75002f538d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt @@ -1,14 +1,15 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input-on-failure # Verify that the data_format attributes is pulled from the default value in the # registry when not present in the GraphDef # CHECK: tf.Conv2D # CHECK-SAME: data_format = "NHWC" -# Verify that we can also pull some attributes that are needed to be able to -# create a Graph in memory, like `T`. +# Verify that we don't import derived attributes as these will be added only on +# export. # CHECK: tf.MaxPool -# CHECK-SAME: T = f32 +# CHECK-NOT: T = f32 +# CHECK-SAME: : (tensor) -> tensor node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt index b65984227f6..d147106579d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt @@ -3,6 +3,21 @@ node { name: "unnamed" op: "foo" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_disable_call_shape_inference" + value { + b: true + } + } experimental_debug_info { } } @@ -39,7 +54,7 @@ versions { # Verify that functions from the library are properly imported. # CHECK-LABEL: func @main() { -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} # CHECK-LABEL: func @foo0() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt index bb5e02fedf2..191ff5878ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt @@ -105,7 +105,6 @@ versions { # CHECK: func @main # CHECK: "tf.PartitionedCall"() -# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"] # CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]] # CHECK: func @[[FUNCTION]]() -> tensor # CHECK: return {{.*}} : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt new file mode 100644 index 00000000000..2c93fde5bf2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt @@ -0,0 +1,22 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +versions { + producer: 29 + min_consumer: 12 +} + +# Verify that functions from the library are properly imported. + +# CHECK-LABEL: func @main() { +# CHECK: "tf.Placeholder" +# CHECK-NOT: _output_shapes diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt index 707b04473f3..051b88102be 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt @@ -1,7 +1,6 @@ # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s # CHECK: tf.Const -# CHECK-SAME: _output_shapes = ["tfshape$dim { size: 3 }"] # CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string> node { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir index 3839b000f3a..dc9b5d5b806 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir @@ -62,4 +62,106 @@ func @transposeConv2DBackpropInput_f16( return %0 : tensor<1x28x28x64xf16> } +// CHECK-LABEL: func @transposeFusedBatchNormV3_f32 +func @transposeFusedBatchNormV3_f32( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NCHW" + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeFusedBatchNormV3_f16 +func @transposeFusedBatchNormV3_f16( + %arg0: tensor<1x28x28x64xf16>, + %arg1: tensor<64xf32> +) -> tensor<1x28x28x64xf16> { + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NCHW" + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x28x28x64xf16> +} + +// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32 +func @transposeFusedBatchNormGradV3_f32( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<1x28x28x64xf32>, + %arg2: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: "tf.FusedBatchNormGradV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]], + // CHECK-SAME: data_format = "NCHW" + %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2 + = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %x_backprop : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16 +func @transposeFusedBatchNormGradV3_f16( + %arg0: tensor<1x28x28x64xf16>, + %arg1: tensor<1x28x28x64xf16>, + %arg2: tensor<64xf32> +) -> tensor<1x28x28x64xf16> { + + // CHECK: "tf.FusedBatchNormGradV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]], + // CHECK-SAME: data_format = "NCHW" + %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2 + = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf16>, tensor<1x28x28x64xf16>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf16>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %x_backprop : tensor<1x28x28x64xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir index b52ef1c4f4a..6173fa3026e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir @@ -143,4 +143,106 @@ func @transposeConv2DBackpropInput_f16( return %0 : tensor<1x64x28x28xf16> } +// CHECK-LABEL: func @transposeFusedBatchNormV3_f32 +func @transposeFusedBatchNormV3_f32( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NCHW" + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeFusedBatchNormV3_f16 +func @transposeFusedBatchNormV3_f16( + %arg0: tensor<1x64x28x28xf16>, + %arg1: tensor<64xf32> +) -> tensor<1x64x28x28xf16> { + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NHWC" + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NCHW", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x64x28x28xf16> +} + +// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32 +func @transposeFusedBatchNormGradV3_f32( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<1x28x28x64xf32>, + %arg2: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: "tf.FusedBatchNormGradV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]], + // CHECK-SAME: data_format = "NCHW" + %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2 + = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %x_backprop : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16 +func @transposeFusedBatchNormGradV3_f16( + %arg0: tensor<1x64x28x28xf16>, + %arg1: tensor<1x64x28x28xf16>, + %arg2: tensor<64xf32> +) -> tensor<1x64x28x28xf16> { + + // CHECK: "tf.FusedBatchNormGradV3" + // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]], + // CHECK-SAME: data_format = "NHWC" + %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2 + = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2) + { + data_format = "NCHW", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x64x28x28xf16>, tensor<1x64x28x28xf16>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x64x28x28xf16>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %x_backprop : tensor<1x64x28x28xf16> +} + } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index 22be6537adb..30599b2e437 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -146,3 +146,82 @@ func @transposeConv2DBackpropInput( return %0 : tensor<1x32x32x3xf32> } + +// CHECK-LABEL: func @transposeFusedBatchNormV3 +func @transposeFusedBatchNormV3( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<64xf32>, + // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>, + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeFusedBatchNormGradV3 +func @transposeFusedBatchNormGradV3( + %arg0: tensor<1x28x28x64xf32>, + %arg1: tensor<1x28x28x64xf32>, + %arg2: tensor<64xf32> +) -> tensor<1x28x28x64xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + + // CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + // CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]]) + + // CHECK: "tf.FusedBatchNormGradV3" + // CHECK-SAME: (%[[ARG0_TPOSE]], %[[ARG1_TPOSE]], %arg2, %arg2, %arg2, %arg2) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<1x64x28x28xf32>, + // CHECK-SAME: -> (tensor<1x64x28x28xf32>, + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + + // CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose" + // CHECK-SAME: (%x_backprop, %[[RES_PERM]]) + // CHECK: return %[[RES_TPOSE]] + + %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2 + = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2) + { + data_format = "NHWC", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x28x28x64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %x_backprop : tensor<1x28x28x64xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index e27448e1d0f..e6b3bf08394 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -33,3 +33,40 @@ func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32 return %0 : tensor<1x8x32x32xf32> } + +// CHECK-LABEL: func @transposeFusedBatchNormV3 +func @transposeFusedBatchNormV3( + %arg0: tensor<1x64x28x28xf32>, + %arg1: tensor<64xf32> +) -> tensor<1x64x28x28xf32> { + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + + // CHECK: "tf.FusedBatchNormV3" + // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1) + // CHECK-SAME: data_format = "NHWC" + // CHECK-SAME: (tensor<1x28x28x64xf32>, tensor<64xf32>, + // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>, + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() + // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3 + = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) + { + data_format = "NCHW", + epsilon = 1.001 : f32, + exponential_avg_factor = 1.0 : f32, + is_training = true + } + : (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>) + -> (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>, + tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) + + return %y : tensor<1x64x28x28xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4f9e12736e4..4b38465257d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1,552 +1,1483 @@ // RUN: tf-opt -tf-legalize-hlo %s | FileCheck %s --dump-input-on-failure -//===----------------------------------------------------------------------===// -// Binary op legalizations. -//===----------------------------------------------------------------------===// + +func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + return %0 : tensor<1x32x10x32xi32> +} + +func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + return %0 : tensor<1x32x10x32xi32> +} + +func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + return %0 : tensor +} func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> -%1 = xla_hlo.add %0, %arg0 : tensor<2xi32> -return %1 : tensor<2xi32> + %0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> + %1 = xla_hlo.add %0, %arg0 : tensor<2xi32> + return %1 : tensor<2xi32> } -// CHECK-LABEL: func @add( -// CHECK-SAME: [[VAL_0:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_1:%.*]] = "tf.AddV2"([[VAL_0]], [[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: [[VAL_2:%.*]] = "tf.AddV2"([[VAL_1]], [[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_2]] : tensor<2xi32> func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -return %0 : tensor<1x2xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } -// CHECK-LABEL: func @broadcast_add( -// CHECK-SAME: [[VAL_3:%.*]]: tensor<1xi32>, [[VAL_4:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_5:%.*]] = "tf.AddV2"([[VAL_3]], [[VAL_4]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_5]] : tensor<1x2xi32> func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { -%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> -return %0 : tensor<4x4x4x4xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + return %0 : tensor<4x4x4x4xi32> } -// CHECK-LABEL: func @broadcast_multi_dim_add( -// CHECK-SAME: [[VAL_6:%.*]]: tensor<4x1x1xi32>, [[VAL_7:%.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { -// CHECK: [[VAL_8:%.*]] = "tf.AddV2"([[VAL_6]], [[VAL_7]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> -// CHECK: return [[VAL_8]] : tensor<4x4x4x4xi32> func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> -return %0 : tensor<2xi32> + %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> } -// CHECK-LABEL: func @div( -// CHECK-SAME: [[VAL_9:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_10:%.*]] = "tf.RealDiv"([[VAL_9]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_10]] : tensor<2xi32> func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -return %0 : tensor<1x2xi32> + %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } -// CHECK-LABEL: func @broadcast_div( -// CHECK-SAME: [[VAL_11:%.*]]: tensor<1xi32>, [[VAL_12:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_13:%.*]] = "tf.RealDiv"([[VAL_11]], [[VAL_12]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_13]] : tensor<1x2xi32> func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { -%0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> -return %0 : tensor<4xi32> + %0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + return %0 : tensor<4xi32> } -// CHECK-LABEL: func @shift_left( -// CHECK-SAME: [[VAL_14:%.*]]: tensor<4xi32>, [[VAL_15:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_16:%.*]] = "tf.LeftShift"([[VAL_14]], [[VAL_15]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_16]] : tensor<4xi32> func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { -%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor -return %0 : tensor + %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + return %0 : tensor } -// CHECK-LABEL: func @div_dynamic( -// CHECK-SAME: [[VAL_17:%.*]]: tensor, [[VAL_18:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_19:%.*]] = "tf.RealDiv"([[VAL_17]], [[VAL_18]]) : (tensor, tensor) -> tensor -// CHECK: return [[VAL_19]] : tensor - -func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { -%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor -return %0 : tensor -} -// CHECK-LABEL: func @div_unranked( -// CHECK-SAME: [[VAL_20:%.*]]: tensor<*xi32>, [[VAL_21:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<*xi32>, tensor) -> tensor -// CHECK: return [[VAL_22]] : tensor func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -%0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> -return %0 : tensor<4xf32> + %0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> } -// CHECK-LABEL: func @maximum( -// CHECK-SAME: [[VAL_23:%.*]]: tensor<4xf32>, [[VAL_24:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_25:%.*]] = "tf.Maximum"([[VAL_23]], [[VAL_24]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_25]] : tensor<4xf32> func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -%0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> -return %0 : tensor<4xf32> + %0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> } -// CHECK-LABEL: func @minimum( -// CHECK-SAME: [[VAL_26:%.*]]: tensor<4xf32>, [[VAL_27:%.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: [[VAL_28:%.*]] = "tf.Minimum"([[VAL_26]], [[VAL_27]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return [[VAL_28]] : tensor<4xf32> func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> -return %0 : tensor<2xi32> + %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> } -// CHECK-LABEL: func @mul( -// CHECK-SAME: [[VAL_29:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_30:%.*]] = "tf.Mul"([[VAL_29]], [[VAL_29]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_30]] : tensor<2xi32> func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -return %0 : tensor<1x2xi32> + %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } -// CHECK-LABEL: func @broadcast_mul( -// CHECK-SAME: [[VAL_31:%.*]]: tensor<1xi32>, [[VAL_32:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_33:%.*]] = "tf.Mul"([[VAL_31]], [[VAL_32]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_33]] : tensor<1x2xi32> func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> -return %0 : tensor<2xi32> + %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> } -// CHECK-LABEL: func @real_div( -// CHECK-SAME: [[VAL_34:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_35:%.*]] = "tf.RealDiv"([[VAL_34]], [[VAL_34]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_35]] : tensor<2xi32> func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -return %0 : tensor<1x2xi32> + %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } -// CHECK-LABEL: func @broadcast_real_div( -// CHECK-SAME: [[VAL_36:%.*]]: tensor<1xi32>, [[VAL_37:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_38:%.*]] = "tf.RealDiv"([[VAL_36]], [[VAL_37]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_38]] : tensor<1x2xi32> func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { -%0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> -return %0 : tensor<2xi32> + %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> } -// CHECK-LABEL: func @sub( -// CHECK-SAME: [[VAL_39:%.*]]: tensor<2xi32>) -> tensor<2xi32> { -// CHECK: [[VAL_40:%.*]] = "tf.Sub"([[VAL_39]], [[VAL_39]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK: return [[VAL_40]] : tensor<2xi32> func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { -%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -return %0 : tensor<1x2xi32> + %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } -// CHECK-LABEL: func @broadcast_sub( -// CHECK-SAME: [[VAL_41:%.*]]: tensor<1xi32>, [[VAL_42:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_43:%.*]] = "tf.Sub"([[VAL_41]], [[VAL_42]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> -// CHECK: return [[VAL_43]] : tensor<1x2xi32> func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { -%0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> -return %0 : tensor<4xi32> + %0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + return %0 : tensor<4xi32> } -// CHECK-LABEL: func @shift_right( -// CHECK-SAME: [[VAL_44:%.*]]: tensor<4xi32>, [[VAL_45:%.*]]: tensor<4xi32>) -> tensor<4xi32> { -// CHECK: [[VAL_46:%.*]] = "tf.RightShift"([[VAL_44]], [[VAL_45]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> -// CHECK: return [[VAL_46]] : tensor<4xi32> func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { -%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -return %0 : tensor<2x4xi32> + %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + return %0 : tensor<2x4xi32> } -// CHECK-LABEL: func @broadcast_shift_right( -// CHECK-SAME: [[VAL_47:%.*]]: tensor<4xi32>, [[VAL_48:%.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { -// CHECK: [[VAL_49:%.*]] = "tf.RightShift"([[VAL_47]], [[VAL_48]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK: return [[VAL_49]] : tensor<2x4xi32> -func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { -%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> -return %0 : tensor<4xui8> +func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1> + return %0 : tensor<2xi1> } -// CHECK-LABEL: func @shift_right_unsigned( -// CHECK-SAME: [[VAL_50:%.*]]: tensor<4xui8>, [[VAL_51:%.*]]: tensor<4xui8>) -> tensor<4xui8> { -// CHECK: [[VAL_52:%.*]] = "tf.RightShift"([[VAL_50]], [[VAL_51]]) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> -// CHECK: return [[VAL_52]] : tensor<4xui8> -func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { -%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> -return %0 : tensor<2x4xui8> +func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { + %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> } -// CHECK-LABEL: func @broadcast_shift_right_unsigned( -// CHECK-SAME: [[VAL_53:%.*]]: tensor<4xui8>, [[VAL_54:%.*]]: tensor<2x4xui8>) -> tensor<2x4xui8> { -// CHECK: [[VAL_55:%.*]] = "tf.RightShift"([[VAL_53]], [[VAL_54]]) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> -// CHECK: return [[VAL_55]] : tensor<2x4xui8> -//===----------------------------------------------------------------------===// -// Unary op legalizations. -//===----------------------------------------------------------------------===// +func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { + %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + return %0 : tensor +} + +func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { + %0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { + %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { + %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + return %0 : tensor +} + +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32> + return %0 : tensor<4xi32> +} + +func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { + %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + return %0 : tensor<1x4xi8> +} + +func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0 : tensor +} + +func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32> + return %0 : tensor<4xi32> +} + +func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { + %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + return %0 : tensor<1x4xi8> +} + +func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0 : tensor +} + +func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = xla_hlo.pow %arg0, %arg0 : tensor<2xf32> + return %0 : tensor<2xf32> +} + +func @pow_dynamic(%arg0: tensor) -> tensor { + %0 = xla_hlo.pow %arg0, %arg0 : tensor + return %0 : tensor +} + +func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { + %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> + %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %2 = xla_hlo.constant dense<0> : tensor<3xi32> + %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> + %8 = xla_hlo.constant dense<1> : tensor<3xi32> + %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> + %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> + %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %14 : tensor<2x3xi32> +} + +func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = xla_hlo.constant dense<0> : tensor<3xi32> + %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> + %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> + %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> + %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> + %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %14 : tensor<2x3xi32> +} + +func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> + %1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> + %2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} + +func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { + %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> + return %2 : tensor<2x3xf16> +} + +func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + return %0 : tensor +} + +func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + return %0 : tensor +} + +func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + return %0 : tensor +} + +func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + +func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + +func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + return %2 : tensor<6x3xf32> +} + +func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { + %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + return %2 : tensor<3x6xf32> +} + +func @const() -> tensor<2xi32> { + %0 = xla_hlo.constant dense<0> : tensor<2xi32> + return %0 : tensor<2xi32> +} + +func @const_dynamic_output() -> tensor<*xi32> { + %0 = xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32> + return %0 : tensor<*xi32> +} + +func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { + %0 = xla_hlo.constant dense<0> : tensor + %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + return %1 : tensor<1xi32> +} + +func @relu_unranked(%arg0: tensor) -> tensor { + %0 = xla_hlo.constant dense<0> : tensor + %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + return %1 : tensor +} + +func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { + %0 = xla_hlo.constant dense<0> : tensor + %1 = xla_hlo.constant dense<6> : tensor + %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + return %3 : tensor<1xi32> +} + +func @relu6_unranked(%arg0: tensor) -> tensor { + %0 = xla_hlo.constant dense<0> : tensor + %1 = xla_hlo.constant dense<6> : tensor + %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + return %3 : tensor +} + +func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> + %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + return %3 : tensor<4x8xf32> +} + +func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> + return %0 : tensor<3x2xi32> +} + +func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + return %2 : tensor<3x2xf32> +} + +func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32> + %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + return %2 : tensor<3x2x1xf32> +} + +func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + return %2 : tensor<3x2x1xf32> +} + +func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { + %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor<4x?xf32> + return %2 : tensor<4x?xf32> +} + +func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @abs( -// CHECK-SAME: [[VAL_0:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_1:%.*]] = "tf.Abs"([[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_1]] : tensor<2xf32> func @abs_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.abs"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @abs_dynamic( -// CHECK-SAME: [[VAL_2:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_3:%.*]] = "tf.Abs"([[VAL_2]]) : (tensor) -> tensor -// CHECK: return [[VAL_3]] : tensor func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @abs_unranked( -// CHECK-SAME: [[VAL_4:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_5:%.*]] = "tf.Abs"([[VAL_4]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_5]] : tensor<*xf32> func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @ceil( -// CHECK-SAME: [[VAL_6:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_7:%.*]] = "tf.Ceil"([[VAL_6]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_7]] : tensor<2xf32> func @ceil_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.ceil"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @ceil_dynamic( -// CHECK-SAME: [[VAL_8:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_9:%.*]] = "tf.Ceil"([[VAL_8]]) : (tensor) -> tensor -// CHECK: return [[VAL_9]] : tensor func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @ceil_unranked( -// CHECK-SAME: [[VAL_10:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_11:%.*]] = "tf.Ceil"([[VAL_10]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_11]] : tensor<*xf32> + +func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + %0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @cos( -// CHECK-SAME: [[VAL_12:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_13:%.*]] = "tf.Cos"([[VAL_12]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_13]] : tensor<2xf32> func @cos_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.cos"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @cos_dynamic( -// CHECK-SAME: [[VAL_14:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_15:%.*]] = "tf.Cos"([[VAL_14]]) : (tensor) -> tensor -// CHECK: return [[VAL_15]] : tensor func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @cos_unranked( -// CHECK-SAME: [[VAL_16:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_17:%.*]] = "tf.Cos"([[VAL_16]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_17]] : tensor<*xf32> func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @exp( -// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_19:%.*]] = "tf.Exp"([[VAL_18]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_19]] : tensor<2xf32> func @exp_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.exp"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @exp_dynamic( -// CHECK-SAME: [[VAL_20:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_21:%.*]] = "tf.Exp"([[VAL_20]]) : (tensor) -> tensor -// CHECK: return [[VAL_21]] : tensor func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @exp_unranked( -// CHECK-SAME: [[VAL_22:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_23:%.*]] = "tf.Exp"([[VAL_22]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_23]] : tensor<*xf32> func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @floor( -// CHECK-SAME: [[VAL_24:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_25:%.*]] = "tf.Floor"([[VAL_24]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_25]] : tensor<2xf32> func @floor_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @floor_dynamic( -// CHECK-SAME: [[VAL_26:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_27:%.*]] = "tf.Floor"([[VAL_26]]) : (tensor) -> tensor -// CHECK: return [[VAL_27]] : tensor func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @floor_unranked( -// CHECK-SAME: [[VAL_28:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_29:%.*]] = "tf.Floor"([[VAL_28]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_29]] : tensor<*xf32> func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { %0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> return %0 : tensor<2xi1> } -// CHECK-LABEL: func @is_finite( -// CHECK-SAME: [[VAL_30:%.*]]: tensor<2xf32>) -> tensor<2xi1> { -// CHECK: [[VAL_31:%.*]] = "tf.IsFinite"([[VAL_30]]) : (tensor<2xf32>) -> tensor<2xi1> -// CHECK: return [[VAL_31]] : tensor<2xi1> func @is_finite_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.is_finite"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @is_finite_dynamic( -// CHECK-SAME: [[VAL_32:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_33:%.*]] = "tf.IsFinite"([[VAL_32]]) : (tensor) -> tensor -// CHECK: return [[VAL_33]] : tensor func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { %0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> return %0 : tensor<*xi1> } -// CHECK-LABEL: func @is_finite_unranked( -// CHECK-SAME: [[VAL_34:%.*]]: tensor<*xf32>) -> tensor<*xi1> { -// CHECK: [[VAL_35:%.*]] = "tf.IsFinite"([[VAL_34]]) : (tensor<*xf32>) -> tensor<*xi1> -// CHECK: return [[VAL_35]] : tensor<*xi1> func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @log( -// CHECK-SAME: [[VAL_36:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_37:%.*]] = "tf.Log"([[VAL_36]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_37]] : tensor<2xf32> func @log_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.log"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @log_dynamic( -// CHECK-SAME: [[VAL_38:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_39:%.*]] = "tf.Log"([[VAL_38]]) : (tensor) -> tensor -// CHECK: return [[VAL_39]] : tensor func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @log_unranked( -// CHECK-SAME: [[VAL_40:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_41:%.*]] = "tf.Log"([[VAL_40]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_41]] : tensor<*xf32> func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @log1p( -// CHECK-SAME: [[VAL_42:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_43:%.*]] = "tf.Log1p"([[VAL_42]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_43]] : tensor<2xf32> func @log1p_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @log1p_dynamic( -// CHECK-SAME: [[VAL_44:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_45:%.*]] = "tf.Log1p"([[VAL_44]]) : (tensor) -> tensor -// CHECK: return [[VAL_45]] : tensor func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @log1p_unranked( -// CHECK-SAME: [[VAL_46:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_47:%.*]] = "tf.Log1p"([[VAL_46]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_47]] : tensor<*xf32> - -func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { - %0 = "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> - return %0 : tensor<*xi1> -} -// CHECK-LABEL: func @not_op_unranked( -// CHECK-SAME: [[VAL_48:%.*]]: tensor<*xi1>) -> tensor<*xi1> { -// CHECK: [[VAL_49:%.*]] = "tf.LogicalNot"([[VAL_48]]) : (tensor<*xi1>) -> tensor<*xi1> -// CHECK: return [[VAL_49]] : tensor<*xi1> func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @neg( -// CHECK-SAME: [[VAL_50:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_51:%.*]] = "tf.Neg"([[VAL_50]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_51]] : tensor<2xf32> func @neg_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.neg"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @neg_dynamic( -// CHECK-SAME: [[VAL_52:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_53:%.*]] = "tf.Neg"([[VAL_52]]) : (tensor) -> tensor -// CHECK: return [[VAL_53]] : tensor func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @neg_unranked( -// CHECK-SAME: [[VAL_54:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_55:%.*]] = "tf.Neg"([[VAL_54]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_55]] : tensor<*xf32> + +func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = xla_hlo.constant dense<5.000000e-01> : tensor + %1 = xla_hlo.constant dense<2> : tensor<1xi64> + %2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32> + %3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32> + %4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32> + %5 = xla_hlo.multiply %4, %2 : tensor<2xf32> + %6 = xla_hlo.add %5, %2 : tensor<2xf32> + return %6 : tensor<2xf32> +} func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @sin( -// CHECK-SAME: [[VAL_56:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_57:%.*]] = "tf.Sin"([[VAL_56]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_57]] : tensor<2xf32> func @sin_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.sin"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @sin_dynamic( -// CHECK-SAME: [[VAL_58:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_59:%.*]] = "tf.Sin"([[VAL_58]]) : (tensor) -> tensor -// CHECK: return [[VAL_59]] : tensor func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @sin_unranked( -// CHECK-SAME: [[VAL_60:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_61:%.*]] = "tf.Sin"([[VAL_60]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_61]] : tensor<*xf32> func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @rsqrt( -// CHECK-SAME: [[VAL_62:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_63:%.*]] = "tf.Rsqrt"([[VAL_62]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_63]] : tensor<2xf32> func @rsqrt_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @rsqrt_dynamic( -// CHECK-SAME: [[VAL_64:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_65:%.*]] = "tf.Rsqrt"([[VAL_64]]) : (tensor) -> tensor -// CHECK: return [[VAL_65]] : tensor func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @rsqrt_unranked( -// CHECK-SAME: [[VAL_66:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_67:%.*]] = "tf.Rsqrt"([[VAL_66]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_67]] : tensor<*xf32> func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @sqrt( -// CHECK-SAME: [[VAL_68:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_69:%.*]] = "tf.Sqrt"([[VAL_68]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_69]] : tensor<2xf32> func @sqrt_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.sqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @sqrt_dynamic( -// CHECK-SAME: [[VAL_70:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_71:%.*]] = "tf.Sqrt"([[VAL_70]]) : (tensor) -> tensor -// CHECK: return [[VAL_71]] : tensor func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @sqrt_unranked( -// CHECK-SAME: [[VAL_72:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_73:%.*]] = "tf.Sqrt"([[VAL_72]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_73]] : tensor<*xf32> func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } -// CHECK-LABEL: func @tanh( -// CHECK-SAME: [[VAL_74:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_75:%.*]] = "tf.Tanh"([[VAL_74]]) : (tensor<2xf32>) -> tensor<2xf32> -// CHECK: return [[VAL_75]] : tensor<2xf32> func @tanh_dynamic(%arg0: tensor) -> tensor { %0 = "xla_hlo.tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } -// CHECK-LABEL: func @tanh_dynamic( -// CHECK-SAME: [[VAL_76:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_77:%.*]] = "tf.Tanh"([[VAL_76]]) : (tensor) -> tensor -// CHECK: return [[VAL_77]] : tensor func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK-LABEL: func @tanh_unranked( -// CHECK-SAME: [[VAL_78:%.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAL_79:%.*]] = "tf.Tanh"([[VAL_78]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: return [[VAL_79]] : tensor<*xf32> +func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +func @bitcast_dynamic(%arg0: tensor) -> tensor { + %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { + %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + %4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + return %6 : tensor<1x2x3x4xf32> +} + +func @size_rank_one_i32(%arg0: tensor) -> tensor { + %0 = xla_hlo.constant dense<1> : tensor + return %0 : tensor +} + +func @size_rank_one_i64(%arg0: tensor) -> tensor { + %0 = xla_hlo.constant dense<1> : tensor + return %0 : tensor +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: func @biasAdd_NHWC( +// CHECK-SAME: [[VAL_0:%.*]]: tensor<1x32x10x32xi32>, [[VAL_1:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: [[VAL_2:%.*]] = "tf.AddV2"([[VAL_0]], [[VAL_1]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return [[VAL_2]] : tensor<1x32x10x32xi32> +// CHECK: } + +// CHECK-LABEL: func @biasAdd_NCHW( +// CHECK-SAME: [[VAL_3:%.*]]: tensor<1x32x10x32xi32>, [[VAL_4:%.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> { +// CHECK: [[VAL_5:%.*]] = "tf.AddV2"([[VAL_3]], [[VAL_4]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> +// CHECK: return [[VAL_5]] : tensor<1x32x10x32xi32> +// CHECK: } + +// CHECK-LABEL: func @biasAdd_dynamic( +// CHECK-SAME: [[VAL_6:%.*]]: tensor, [[VAL_7:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_8:%.*]] = "tf.AddV2"([[VAL_6]], [[VAL_7]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_8]] : tensor +// CHECK: } + +// CHECK-LABEL: func @add( +// CHECK-SAME: [[VAL_9:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_10:%.*]] = "tf.AddV2"([[VAL_9]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: [[VAL_11:%.*]] = "tf.AddV2"([[VAL_10]], [[VAL_9]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_11]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_add( +// CHECK-SAME: [[VAL_12:%.*]]: tensor<1xi32>, [[VAL_13:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_13]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_14]] : tensor<1x2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_multi_dim_add( +// CHECK-SAME: [[VAL_15:%.*]]: tensor<4x1x1xi32>, [[VAL_16:%.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { +// CHECK: [[VAL_17:%.*]] = "tf.AddV2"([[VAL_15]], [[VAL_16]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> +// CHECK: return [[VAL_17]] : tensor<4x4x4x4xi32> +// CHECK: } + +// CHECK-LABEL: func @div( +// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_19:%.*]] = "tf.RealDiv"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_19]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_div( +// CHECK-SAME: [[VAL_20:%.*]]: tensor<1xi32>, [[VAL_21:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_22:%.*]] = "tf.RealDiv"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_22]] : tensor<1x2xi32> +// CHECK: } + +// CHECK-LABEL: func @shift_left( +// CHECK-SAME: [[VAL_23:%.*]]: tensor<4xi32>, [[VAL_24:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_25:%.*]] = "tf.LeftShift"([[VAL_23]], [[VAL_24]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_25]] : tensor<4xi32> +// CHECK: } + +// CHECK-LABEL: func @div_dynamic( +// CHECK-SAME: [[VAL_26:%.*]]: tensor, [[VAL_27:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_28:%.*]] = "tf.RealDiv"([[VAL_26]], [[VAL_27]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_28]] : tensor +// CHECK: } + +// CHECK-LABEL: func @maximum( +// CHECK-SAME: [[VAL_29:%.*]]: tensor<4xf32>, [[VAL_30:%.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: [[VAL_31:%.*]] = "tf.Maximum"([[VAL_29]], [[VAL_30]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return [[VAL_31]] : tensor<4xf32> +// CHECK: } + +// CHECK-LABEL: func @minimum( +// CHECK-SAME: [[VAL_32:%.*]]: tensor<4xf32>, [[VAL_33:%.*]]: tensor<4xf32>) -> tensor<4xf32> { +// CHECK: [[VAL_34:%.*]] = "tf.Minimum"([[VAL_32]], [[VAL_33]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: return [[VAL_34]] : tensor<4xf32> +// CHECK: } + +// CHECK-LABEL: func @mul( +// CHECK-SAME: [[VAL_35:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_36:%.*]] = "tf.Mul"([[VAL_35]], [[VAL_35]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_36]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_mul( +// CHECK-SAME: [[VAL_37:%.*]]: tensor<1xi32>, [[VAL_38:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_39:%.*]] = "tf.Mul"([[VAL_37]], [[VAL_38]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_39]] : tensor<1x2xi32> +// CHECK: } + +// CHECK-LABEL: func @real_div( +// CHECK-SAME: [[VAL_40:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_41:%.*]] = "tf.RealDiv"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_41]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_real_div( +// CHECK-SAME: [[VAL_42:%.*]]: tensor<1xi32>, [[VAL_43:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_44:%.*]] = "tf.RealDiv"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_44]] : tensor<1x2xi32> +// CHECK: } + +// CHECK-LABEL: func @sub( +// CHECK-SAME: [[VAL_45:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_46:%.*]] = "tf.Sub"([[VAL_45]], [[VAL_45]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_46]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_sub( +// CHECK-SAME: [[VAL_47:%.*]]: tensor<1xi32>, [[VAL_48:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { +// CHECK: [[VAL_49:%.*]] = "tf.Sub"([[VAL_47]], [[VAL_48]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: return [[VAL_49]] : tensor<1x2xi32> +// CHECK: } + +// CHECK-LABEL: func @shift_right( +// CHECK-SAME: [[VAL_50:%.*]]: tensor<4xi32>, [[VAL_51:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_52:%.*]] = "tf.RightShift"([[VAL_50]], [[VAL_51]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_52]] : tensor<4xi32> +// CHECK: } + +// CHECK-LABEL: func @broadcast_shift_right( +// CHECK-SAME: [[VAL_53:%.*]]: tensor<4xi32>, [[VAL_54:%.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> { +// CHECK: [[VAL_55:%.*]] = "tf.RightShift"([[VAL_53]], [[VAL_54]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: return [[VAL_55]] : tensor<2x4xi32> +// CHECK: } + +// CHECK-LABEL: func @and( +// CHECK-SAME: [[VAL_56:%.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: [[VAL_57:%.*]] = "tf.LogicalAnd"([[VAL_56]], [[VAL_56]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return [[VAL_57]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @and_broadcast( +// CHECK-SAME: [[VAL_58:%.*]]: tensor<1xi1>, [[VAL_59:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: [[VAL_60:%.*]] = "tf.LogicalAnd"([[VAL_58]], [[VAL_59]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return [[VAL_60]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @and_dynamic( +// CHECK-SAME: [[VAL_61:%.*]]: tensor, [[VAL_62:%.*]]: tensor<1xi1>) -> tensor { +// CHECK: [[VAL_63:%.*]] = "tf.LogicalAnd"([[VAL_61]], [[VAL_62]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return [[VAL_63]] : tensor +// CHECK: } + +// CHECK-LABEL: func @or( +// CHECK-SAME: [[VAL_64:%.*]]: tensor<2xi1>) -> tensor<2xi1> { +// CHECK: [[VAL_65:%.*]] = "tf.LogicalOr"([[VAL_64]], [[VAL_64]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> +// CHECK: return [[VAL_65]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @or_broadcast( +// CHECK-SAME: [[VAL_66:%.*]]: tensor<1xi1>, [[VAL_67:%.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> { +// CHECK: [[VAL_68:%.*]] = "tf.LogicalOr"([[VAL_66]], [[VAL_67]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> +// CHECK: return [[VAL_68]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @or_dynamic( +// CHECK-SAME: [[VAL_69:%.*]]: tensor, [[VAL_70:%.*]]: tensor<1xi1>) -> tensor { +// CHECK: [[VAL_71:%.*]] = "tf.LogicalOr"([[VAL_69]], [[VAL_70]]) : (tensor, tensor<1xi1>) -> tensor +// CHECK: return [[VAL_71]] : tensor +// CHECK: } + +// CHECK-LABEL: func @bitwise_or( +// CHECK-SAME: [[VAL_72:%.*]]: tensor<4xi32>, [[VAL_73:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_74:%.*]] = "tf.BitwiseOr"([[VAL_72]], [[VAL_73]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_74]] : tensor<4xi32> +// CHECK: } + +// CHECK-LABEL: func @bitwise_or_broadcast( +// CHECK-SAME: [[VAL_75:%.*]]: tensor<1xi8>, [[VAL_76:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: [[VAL_77:%.*]] = "tf.BitwiseOr"([[VAL_75]], [[VAL_76]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return [[VAL_77]] : tensor<1x4xi8> +// CHECK: } + +// CHECK-LABEL: func @bitwise_or_dynamic( +// CHECK-SAME: [[VAL_78:%.*]]: tensor, [[VAL_79:%.*]]: tensor<1xi32>) -> tensor { +// CHECK: [[VAL_80:%.*]] = "tf.BitwiseOr"([[VAL_78]], [[VAL_79]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return [[VAL_80]] : tensor +// CHECK: } + +// CHECK-LABEL: func @bitwise_and( +// CHECK-SAME: [[VAL_81:%.*]]: tensor<4xi32>, [[VAL_82:%.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: [[VAL_83:%.*]] = "tf.BitwiseAnd"([[VAL_81]], [[VAL_82]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +// CHECK: return [[VAL_83]] : tensor<4xi32> +// CHECK: } + +// CHECK-LABEL: func @bitwise_and_broadcast( +// CHECK-SAME: [[VAL_84:%.*]]: tensor<1xi8>, [[VAL_85:%.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> { +// CHECK: [[VAL_86:%.*]] = "tf.BitwiseAnd"([[VAL_84]], [[VAL_85]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: return [[VAL_86]] : tensor<1x4xi8> +// CHECK: } + +// CHECK-LABEL: func @bitwise_and_dynamic( +// CHECK-SAME: [[VAL_87:%.*]]: tensor, [[VAL_88:%.*]]: tensor<1xi32>) -> tensor { +// CHECK: [[VAL_89:%.*]] = "tf.BitwiseAnd"([[VAL_87]], [[VAL_88]]) : (tensor, tensor<1xi32>) -> tensor +// CHECK: return [[VAL_89]] : tensor +// CHECK: } + +// CHECK-LABEL: func @pow( +// CHECK-SAME: [[VAL_90:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_91:%.*]] = "tf.Pow"([[VAL_90]], [[VAL_90]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_91]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @pow_dynamic( +// CHECK-SAME: [[VAL_92:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_93:%.*]] = "tf.Pow"([[VAL_92]], [[VAL_92]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_93]] : tensor +// CHECK: } + +// CHECK-LABEL: func @floordiv_broadcast_i32( +// CHECK-SAME: [[VAL_94:%.*]]: tensor<2x3xi32>, [[VAL_95:%.*]]: tensor<3xi32>) -> tensor<2x3xi32> { +// CHECK: [[VAL_96:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: [[VAL_97:%.*]] = "tf.Less"([[VAL_94]], [[VAL_96]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: [[VAL_98:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: [[VAL_99:%.*]] = "tf.Less"([[VAL_95]], [[VAL_98]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: [[VAL_100:%.*]] = "tf.Equal"([[VAL_97]], [[VAL_99]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> +// CHECK: [[VAL_101:%.*]] = "tf.RealDiv"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_102:%.*]] = "tf.Abs"([[VAL_94]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_103:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: [[VAL_104:%.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: [[VAL_105:%.*]] = "tf.Sub"([[VAL_103]], [[VAL_104]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +// CHECK: [[VAL_106:%.*]] = "tf.AddV2"([[VAL_102]], [[VAL_105]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_107:%.*]] = "tf.Neg"([[VAL_106]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_108:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: [[VAL_109:%.*]] = "tf.RealDiv"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_110:%.*]] = "tf.Select"([[VAL_100]], [[VAL_101]], [[VAL_109]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return [[VAL_110]] : tensor<2x3xi32> +// CHECK: } + +// CHECK-LABEL: func @floordiv_reverse_broadcast_i32( +// CHECK-SAME: [[VAL_111:%.*]]: tensor<3xi32>, [[VAL_112:%.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> { +// CHECK: [[VAL_113:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: [[VAL_114:%.*]] = "tf.Less"([[VAL_111]], [[VAL_113]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> +// CHECK: [[VAL_115:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: [[VAL_116:%.*]] = "tf.Less"([[VAL_112]], [[VAL_115]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> +// CHECK: [[VAL_117:%.*]] = "tf.Equal"([[VAL_114]], [[VAL_116]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> +// CHECK: [[VAL_118:%.*]] = "tf.RealDiv"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_119:%.*]] = "tf.Abs"([[VAL_111]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: [[VAL_120:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_121:%.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> +// CHECK: [[VAL_122:%.*]] = "tf.Sub"([[VAL_120]], [[VAL_121]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_123:%.*]] = "tf.AddV2"([[VAL_119]], [[VAL_122]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_124:%.*]] = "tf.Neg"([[VAL_123]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_125:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_126:%.*]] = "tf.RealDiv"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_127:%.*]] = "tf.Select"([[VAL_117]], [[VAL_118]], [[VAL_126]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: return [[VAL_127]] : tensor<2x3xi32> +// CHECK: } + +// CHECK-LABEL: func @floordiv_f32( +// CHECK-SAME: [[VAL_128:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_129:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_130:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_131:%.*]] = "tf.FloorDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_131]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @floordiv_f16_broadcast( +// CHECK-SAME: [[VAL_132:%.*]]: tensor<2x3xf16>, [[VAL_133:%.*]]: tensor<3xf16>) -> tensor<2x3xf16> { +// CHECK: [[VAL_134:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: [[VAL_135:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: [[VAL_136:%.*]] = "tf.FloorDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: return [[VAL_136]] : tensor<2x3xf16> +// CHECK: } + +// CHECK-LABEL: func @equal( +// CHECK-SAME: [[VAL_137:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_138:%.*]] = "tf.Equal"([[VAL_137]], [[VAL_137]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_138]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @equal_dynamic( +// CHECK-SAME: [[VAL_139:%.*]]: tensor, [[VAL_140:%.*]]: tensor<1xi32>) -> tensor { +// CHECK: [[VAL_141:%.*]] = "tf.Equal"([[VAL_139]], [[VAL_140]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return [[VAL_141]] : tensor +// CHECK: } + +// CHECK-LABEL: func @equal_broadcast( +// CHECK-SAME: [[VAL_142:%.*]]: tensor<1xi32>, [[VAL_143:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_144:%.*]] = "tf.Equal"([[VAL_142]], [[VAL_143]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_144]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: [[VAL_145:%.*]]: tensor<2xi32>, [[VAL_146:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_147:%.*]] = "tf.Equal"([[VAL_145]], [[VAL_146]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_147]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable( +// CHECK-SAME: [[VAL_148:%.*]]: tensor, [[VAL_149:%.*]]: tensor<1xi32>) -> tensor { +// CHECK: [[VAL_150:%.*]] = "tf.Equal"([[VAL_148]], [[VAL_149]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return [[VAL_150]] : tensor +// CHECK: } + +// CHECK-LABEL: func @notequal( +// CHECK-SAME: [[VAL_151:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_152:%.*]] = "tf.NotEqual"([[VAL_151]], [[VAL_151]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_152]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @notequal_broadcast( +// CHECK-SAME: [[VAL_153:%.*]]: tensor<1xi32>, [[VAL_154:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_155:%.*]] = "tf.NotEqual"([[VAL_153]], [[VAL_154]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_155]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error( +// CHECK-SAME: [[VAL_156:%.*]]: tensor<2xi32>, [[VAL_157:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_158:%.*]] = "tf.NotEqual"([[VAL_156]], [[VAL_157]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_158]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable( +// CHECK-SAME: [[VAL_159:%.*]]: tensor, [[VAL_160:%.*]]: tensor<1xi32>) -> tensor { +// CHECK: [[VAL_161:%.*]] = "tf.NotEqual"([[VAL_159]], [[VAL_160]]) {incompatible_shape_error = true} : (tensor, tensor<1xi32>) -> tensor +// CHECK: return [[VAL_161]] : tensor +// CHECK: } + +// CHECK-LABEL: func @greater( +// CHECK-SAME: [[VAL_162:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_163:%.*]] = "tf.Greater"([[VAL_162]], [[VAL_162]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_163]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @broadcast_greater( +// CHECK-SAME: [[VAL_164:%.*]]: tensor<1xi32>, [[VAL_165:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_166:%.*]] = "tf.Greater"([[VAL_164]], [[VAL_165]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_166]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @greater_equal( +// CHECK-SAME: [[VAL_167:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_168:%.*]] = "tf.GreaterEqual"([[VAL_167]], [[VAL_167]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_168]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @broadcast_greater_equal( +// CHECK-SAME: [[VAL_169:%.*]]: tensor<1xi32>, [[VAL_170:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_171:%.*]] = "tf.GreaterEqual"([[VAL_169]], [[VAL_170]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_171]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @less( +// CHECK-SAME: [[VAL_172:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_173:%.*]] = "tf.Less"([[VAL_172]], [[VAL_172]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_173]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @broadcast_less( +// CHECK-SAME: [[VAL_174:%.*]]: tensor<1xi32>, [[VAL_175:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_176:%.*]] = "tf.Less"([[VAL_174]], [[VAL_175]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_176]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @less_equal( +// CHECK-SAME: [[VAL_177:%.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: [[VAL_178:%.*]] = "tf.LessEqual"([[VAL_177]], [[VAL_177]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return [[VAL_178]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @broadcast_less_equal( +// CHECK-SAME: [[VAL_179:%.*]]: tensor<1xi32>, [[VAL_180:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> { +// CHECK: [[VAL_181:%.*]] = "tf.LessEqual"([[VAL_179]], [[VAL_180]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> +// CHECK: return [[VAL_181]] : tensor<1x2xi1> +// CHECK: } + +// CHECK-LABEL: func @concat_v2( +// CHECK-SAME: [[VAL_182:%.*]]: tensor<3x3xf32>, [[VAL_183:%.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> { +// CHECK: [[VAL_184:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_185:%.*]] = "tf.ConcatV2"([[VAL_182]], [[VAL_183]], [[VAL_184]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> +// CHECK: return [[VAL_185]] : tensor<6x3xf32> +// CHECK: } + +// CHECK-LABEL: func @concat_v2_1d_axis( +// CHECK-SAME: [[VAL_186:%.*]]: tensor<3x3xf32>, [[VAL_187:%.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> { +// CHECK: [[VAL_188:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: [[VAL_189:%.*]] = "tf.ConcatV2"([[VAL_186]], [[VAL_187]], [[VAL_188]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> +// CHECK: return [[VAL_189]] : tensor<3x6xf32> +// CHECK: } + +// CHECK-LABEL: func @const() -> tensor<2xi32> { +// CHECK: [[VAL_190:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: return [[VAL_190]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @const_dynamic_output() -> tensor<*xi32> { +// CHECK: [[VAL_191:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<*xi32> +// CHECK: return [[VAL_191]] : tensor<*xi32> +// CHECK: } + +// CHECK-LABEL: func @relu( +// CHECK-SAME: [[VAL_192:%.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: [[VAL_193:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_194:%.*]] = "tf.Maximum"([[VAL_193]], [[VAL_192]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> +// CHECK: return [[VAL_194]] : tensor<1xi32> +// CHECK: } + +// CHECK-LABEL: func @relu_unranked( +// CHECK-SAME: [[VAL_195:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_196:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_197:%.*]] = "tf.Maximum"([[VAL_196]], [[VAL_195]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_197]] : tensor +// CHECK: } + +// CHECK-LABEL: func @relu6( +// CHECK-SAME: [[VAL_198:%.*]]: tensor<1xi32>) -> tensor<1xi32> { +// CHECK: [[VAL_199:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_200:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: [[VAL_201:%.*]] = "tf.Minimum"([[VAL_198]], [[VAL_200]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: [[VAL_202:%.*]] = "tf.Maximum"([[VAL_201]], [[VAL_199]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: return [[VAL_202]] : tensor<1xi32> +// CHECK: } + +// CHECK-LABEL: func @relu6_unranked( +// CHECK-SAME: [[VAL_203:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_204:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_205:%.*]] = "tf.Const"() {value = dense<6> : tensor} : () -> tensor +// CHECK: [[VAL_206:%.*]] = "tf.Minimum"([[VAL_203]], [[VAL_205]]) : (tensor, tensor) -> tensor +// CHECK: [[VAL_207:%.*]] = "tf.Maximum"([[VAL_206]], [[VAL_204]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_207]] : tensor +// CHECK: } + +// CHECK-LABEL: func @relu_grad( +// CHECK-SAME: [[VAL_208:%.*]]: tensor<4x8xf32>, [[VAL_209:%.*]]: tensor) -> tensor<4x8xf32> { +// CHECK: [[VAL_210:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK: [[VAL_211:%.*]] = "tf.Greater"([[VAL_209]], [[VAL_210]]) : (tensor, tensor) -> tensor +// CHECK: [[VAL_212:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32> +// CHECK: [[VAL_213:%.*]] = "tf.Select"([[VAL_211]], [[VAL_208]], [[VAL_212]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK: return [[VAL_213]] : tensor<4x8xf32> +// CHECK: } + +// CHECK-LABEL: func @select( +// CHECK-SAME: [[VAL_214:%.*]]: tensor<2xi1>, [[VAL_215:%.*]]: tensor<2xi32>, [[VAL_216:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_217:%.*]] = "tf.Select"([[VAL_214]], [[VAL_215]], [[VAL_216]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_217]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @select_float( +// CHECK-SAME: [[VAL_218:%.*]]: tensor<2xi1>, [[VAL_219:%.*]]: tensor<2xf32>, [[VAL_220:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_221:%.*]] = "tf.Select"([[VAL_218]], [[VAL_219]], [[VAL_220]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_221]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @select_multidimensional( +// CHECK-SAME: [[VAL_222:%.*]]: tensor<3x2xi1>, [[VAL_223:%.*]]: tensor<3x2xi32>, [[VAL_224:%.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> { +// CHECK: [[VAL_225:%.*]] = "tf.Select"([[VAL_222]], [[VAL_223]], [[VAL_224]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +// CHECK: return [[VAL_225]] : tensor<3x2xi32> +// CHECK: } + +// CHECK-LABEL: func @selectv2( +// CHECK-SAME: [[VAL_226:%.*]]: tensor<2xi1>, [[VAL_227:%.*]]: tensor<2xi32>, [[VAL_228:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_229:%.*]] = "tf.Select"([[VAL_226]], [[VAL_227]], [[VAL_228]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_229]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @selectv2_pred_scalar( +// CHECK-SAME: [[VAL_230:%.*]]: tensor, [[VAL_231:%.*]]: tensor<2xi32>, [[VAL_232:%.*]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: [[VAL_233:%.*]] = "tf.Select"([[VAL_230]], [[VAL_231]], [[VAL_232]]) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: return [[VAL_233]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @transpose_2d( +// CHECK-SAME: [[VAL_234:%.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> { +// CHECK: [[VAL_235:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_236:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_237:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_238:%.*]] = "tf.Transpose"([[VAL_234]], [[VAL_237]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> +// CHECK: return [[VAL_238]] : tensor<3x2xf32> +// CHECK: } + +// CHECK-LABEL: func @transpose_3d_int32( +// CHECK-SAME: [[VAL_239:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAL_240:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK: [[VAL_241:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: [[VAL_242:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: [[VAL_243:%.*]] = "tf.Transpose"([[VAL_239]], [[VAL_242]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return [[VAL_243]] : tensor<3x2x1xf32> +// CHECK: } + +// CHECK-LABEL: func @transpose_3d( +// CHECK-SAME: [[VAL_244:%.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAL_245:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: [[VAL_246:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: [[VAL_247:%.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: [[VAL_248:%.*]] = "tf.Transpose"([[VAL_244]], [[VAL_247]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> +// CHECK: return [[VAL_248]] : tensor<3x2x1xf32> +// CHECK: } + +// CHECK-LABEL: func @transpose_dynamic_2d( +// CHECK-SAME: [[VAL_249:%.*]]: tensor) -> tensor<4x?xf32> { +// CHECK: [[VAL_250:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_251:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_252:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_253:%.*]] = "tf.Transpose"([[VAL_249]], [[VAL_252]]) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> +// CHECK: return [[VAL_253]] : tensor<4x?xf32> +// CHECK: } + +// CHECK-LABEL: func @transpose_unranked_2d( +// CHECK-SAME: [[VAL_254:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_255:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_256:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_257:%.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_258:%.*]] = "tf.Transpose"([[VAL_254]], [[VAL_257]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> +// CHECK: return [[VAL_258]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @abs( +// CHECK-SAME: [[VAL_259:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_260:%.*]] = "tf.Abs"([[VAL_259]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_260]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @abs_dynamic( +// CHECK-SAME: [[VAL_261:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_262:%.*]] = "tf.Abs"([[VAL_261]]) : (tensor) -> tensor +// CHECK: return [[VAL_262]] : tensor +// CHECK: } + +// CHECK-LABEL: func @abs_unranked( +// CHECK-SAME: [[VAL_263:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_264:%.*]] = "tf.Abs"([[VAL_263]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_264]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @ceil( +// CHECK-SAME: [[VAL_265:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_266:%.*]] = "tf.Ceil"([[VAL_265]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_266]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @ceil_dynamic( +// CHECK-SAME: [[VAL_267:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_268:%.*]] = "tf.Ceil"([[VAL_267]]) : (tensor) -> tensor +// CHECK: return [[VAL_268]] : tensor +// CHECK: } + +// CHECK-LABEL: func @ceil_unranked( +// CHECK-SAME: [[VAL_269:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_270:%.*]] = "tf.Ceil"([[VAL_269]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_270]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @complex_abs( +// CHECK-SAME: [[VAL_271:%.*]]: tensor<2xcomplex>) -> tensor<2xf32> { +// CHECK: [[VAL_272:%.*]] = "tf.ComplexAbs"([[VAL_271]]) : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK: return [[VAL_272]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @cos( +// CHECK-SAME: [[VAL_273:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_274:%.*]] = "tf.Cos"([[VAL_273]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_274]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @cos_dynamic( +// CHECK-SAME: [[VAL_275:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_276:%.*]] = "tf.Cos"([[VAL_275]]) : (tensor) -> tensor +// CHECK: return [[VAL_276]] : tensor +// CHECK: } + +// CHECK-LABEL: func @cos_unranked( +// CHECK-SAME: [[VAL_277:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_278:%.*]] = "tf.Cos"([[VAL_277]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_278]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @exp( +// CHECK-SAME: [[VAL_279:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_280:%.*]] = "tf.Exp"([[VAL_279]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_280]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @exp_dynamic( +// CHECK-SAME: [[VAL_281:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_282:%.*]] = "tf.Exp"([[VAL_281]]) : (tensor) -> tensor +// CHECK: return [[VAL_282]] : tensor +// CHECK: } + +// CHECK-LABEL: func @exp_unranked( +// CHECK-SAME: [[VAL_283:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_284:%.*]] = "tf.Exp"([[VAL_283]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_284]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @floor( +// CHECK-SAME: [[VAL_285:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_286:%.*]] = "tf.Floor"([[VAL_285]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_286]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @floor_dynamic( +// CHECK-SAME: [[VAL_287:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_288:%.*]] = "tf.Floor"([[VAL_287]]) : (tensor) -> tensor +// CHECK: return [[VAL_288]] : tensor +// CHECK: } + +// CHECK-LABEL: func @floor_unranked( +// CHECK-SAME: [[VAL_289:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_290:%.*]] = "tf.Floor"([[VAL_289]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_290]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @is_finite( +// CHECK-SAME: [[VAL_291:%.*]]: tensor<2xf32>) -> tensor<2xi1> { +// CHECK: [[VAL_292:%.*]] = "tf.IsFinite"([[VAL_291]]) : (tensor<2xf32>) -> tensor<2xi1> +// CHECK: return [[VAL_292]] : tensor<2xi1> +// CHECK: } + +// CHECK-LABEL: func @is_finite_dynamic( +// CHECK-SAME: [[VAL_293:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_294:%.*]] = "tf.IsFinite"([[VAL_293]]) : (tensor) -> tensor +// CHECK: return [[VAL_294]] : tensor +// CHECK: } + +// CHECK-LABEL: func @is_finite_unranked( +// CHECK-SAME: [[VAL_295:%.*]]: tensor<*xf32>) -> tensor<*xi1> { +// CHECK: [[VAL_296:%.*]] = "tf.IsFinite"([[VAL_295]]) : (tensor<*xf32>) -> tensor<*xi1> +// CHECK: return [[VAL_296]] : tensor<*xi1> +// CHECK: } + +// CHECK-LABEL: func @log( +// CHECK-SAME: [[VAL_297:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_298:%.*]] = "tf.Log"([[VAL_297]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_298]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @log_dynamic( +// CHECK-SAME: [[VAL_299:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_300:%.*]] = "tf.Log"([[VAL_299]]) : (tensor) -> tensor +// CHECK: return [[VAL_300]] : tensor +// CHECK: } + +// CHECK-LABEL: func @log_unranked( +// CHECK-SAME: [[VAL_301:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_302:%.*]] = "tf.Log"([[VAL_301]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_302]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @log1p( +// CHECK-SAME: [[VAL_303:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_304:%.*]] = "tf.Log1p"([[VAL_303]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_304]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @log1p_dynamic( +// CHECK-SAME: [[VAL_305:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_306:%.*]] = "tf.Log1p"([[VAL_305]]) : (tensor) -> tensor +// CHECK: return [[VAL_306]] : tensor +// CHECK: } + +// CHECK-LABEL: func @log1p_unranked( +// CHECK-SAME: [[VAL_307:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_308:%.*]] = "tf.Log1p"([[VAL_307]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_308]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @neg( +// CHECK-SAME: [[VAL_309:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_310:%.*]] = "tf.Neg"([[VAL_309]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_310]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @neg_dynamic( +// CHECK-SAME: [[VAL_311:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_312:%.*]] = "tf.Neg"([[VAL_311]]) : (tensor) -> tensor +// CHECK: return [[VAL_312]] : tensor +// CHECK: } + +// CHECK-LABEL: func @neg_unranked( +// CHECK-SAME: [[VAL_313:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_314:%.*]] = "tf.Neg"([[VAL_313]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_314]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @sigmoid( +// CHECK-SAME: [[VAL_315:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_316:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor +// CHECK: [[VAL_317:%.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: [[VAL_318:%.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK: [[VAL_319:%.*]] = "tf.Mul"([[VAL_315]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_320:%.*]] = "tf.Tanh"([[VAL_319]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_321:%.*]] = "tf.Mul"([[VAL_320]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_322:%.*]] = "tf.AddV2"([[VAL_321]], [[VAL_318]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_322]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @sin( +// CHECK-SAME: [[VAL_323:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_324:%.*]] = "tf.Sin"([[VAL_323]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_324]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @sin_dynamic( +// CHECK-SAME: [[VAL_325:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_326:%.*]] = "tf.Sin"([[VAL_325]]) : (tensor) -> tensor +// CHECK: return [[VAL_326]] : tensor +// CHECK: } + +// CHECK-LABEL: func @sin_unranked( +// CHECK-SAME: [[VAL_327:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_328:%.*]] = "tf.Sin"([[VAL_327]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_328]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @rsqrt( +// CHECK-SAME: [[VAL_329:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_330:%.*]] = "tf.Rsqrt"([[VAL_329]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_330]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @rsqrt_dynamic( +// CHECK-SAME: [[VAL_331:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_332:%.*]] = "tf.Rsqrt"([[VAL_331]]) : (tensor) -> tensor +// CHECK: return [[VAL_332]] : tensor +// CHECK: } + +// CHECK-LABEL: func @rsqrt_unranked( +// CHECK-SAME: [[VAL_333:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_334:%.*]] = "tf.Rsqrt"([[VAL_333]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_334]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @sqrt( +// CHECK-SAME: [[VAL_335:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_336:%.*]] = "tf.Sqrt"([[VAL_335]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_336]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @sqrt_dynamic( +// CHECK-SAME: [[VAL_337:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_338:%.*]] = "tf.Sqrt"([[VAL_337]]) : (tensor) -> tensor +// CHECK: return [[VAL_338]] : tensor +// CHECK: } + +// CHECK-LABEL: func @sqrt_unranked( +// CHECK-SAME: [[VAL_339:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_340:%.*]] = "tf.Sqrt"([[VAL_339]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_340]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @tanh( +// CHECK-SAME: [[VAL_341:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_342:%.*]] = "tf.Tanh"([[VAL_341]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_342]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @tanh_dynamic( +// CHECK-SAME: [[VAL_343:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_344:%.*]] = "tf.Tanh"([[VAL_343]]) : (tensor) -> tensor +// CHECK: return [[VAL_344]] : tensor +// CHECK: } + +// CHECK-LABEL: func @tanh_unranked( +// CHECK-SAME: [[VAL_345:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_346:%.*]] = "tf.Tanh"([[VAL_345]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_346]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @bitcast( +// CHECK-SAME: [[VAL_347:%.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: [[VAL_348:%.*]] = "tf.Bitcast"([[VAL_347]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return [[VAL_348]] : tensor<2xf32> +// CHECK: } + +// CHECK-LABEL: func @bitcast_dynamic( +// CHECK-SAME: [[VAL_349:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_350:%.*]] = "tf.Bitcast"([[VAL_349]]) : (tensor) -> tensor +// CHECK: return [[VAL_350]] : tensor +// CHECK: } + +// CHECK-LABEL: func @bitcast_unranked( +// CHECK-SAME: [[VAL_351:%.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAL_352:%.*]] = "tf.Bitcast"([[VAL_351]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAL_352]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func @bitcast_same_widths( +// CHECK-SAME: [[VAL_353:%.*]]: tensor<2xf32>) -> tensor<2xi32> { +// CHECK: [[VAL_354:%.*]] = "tf.Bitcast"([[VAL_353]]) : (tensor<2xf32>) -> tensor<2xi32> +// CHECK: return [[VAL_354]] : tensor<2xi32> +// CHECK: } + +// CHECK-LABEL: func @sign( +// CHECK-SAME: [[VAL_355:%.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: [[VAL_356:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: [[VAL_357:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: [[VAL_358:%.*]] = "tf.NotEqual"([[VAL_355]], [[VAL_355]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: [[VAL_359:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: [[VAL_360:%.*]] = "tf.Sign"([[VAL_355]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: [[VAL_361:%.*]] = "tf.Select"([[VAL_358]], [[VAL_359]], [[VAL_360]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: [[VAL_362:%.*]] = "tf.Select"([[VAL_356]], [[VAL_357]], [[VAL_361]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: return [[VAL_362]] : tensor<1x2x3x4xf32> +// CHECK: } + +// CHECK-LABEL: func @size_rank_one_i32( +// CHECK-SAME: [[VAL_363:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_364:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: return [[VAL_364]] : tensor +// CHECK: } + +// CHECK-LABEL: func @size_rank_one_i64( +// CHECK-SAME: [[VAL_365:%.*]]: tensor) -> tensor { +// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK: return [[VAL_366]] : tensor +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir index 556d586f6c3..c6543f3121e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir @@ -6,7 +6,7 @@ func @main() { // CHECK-NEXT: name: "predicate" // CHECK-NEXT: op: "Const" // CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir index cec9818885c..c0c8284370a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir @@ -22,7 +22,7 @@ func @main() { // CHECK-NEXT: name: "tf.Const" // CHECK-NEXT: op: "Const" // CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } @@ -43,8 +43,8 @@ func @main() { // CHECK-NEXT: name: "tf.Empty" // CHECK-NEXT: op: "Empty" // CHECK-NEXT: input: "tf.Const:output:0" -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK-NEXT: attr { +// CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_FLOAT // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir new file mode 100644 index 00000000000..fb3ee49bbc5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir @@ -0,0 +1,64 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor<10xi32>) -> tensor<10xi32> +attributes {tf.entry_function = {inputs = "input0", outputs = "Placeholder"}} { + %graph = tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> + tf_executor.fetch %0 : tensor<10xi32> + } + return %graph : tensor<10xi32> +} + +// CHECK: node { +// CHECK-NEXT: name: "Placeholder" +// CHECK-NEXT: op: "Placeholder" +// CHECK-NEXT: attr { +// CHECK-NEXT: key: "_output_shapes" +// CHECK-NEXT: value { +// CHECK-NEXT: list { +// CHECK-NEXT: shape { +// CHECK-NEXT: dim { +// CHECK-NEXT: size: 10 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: attr { +// CHECK-NEXT: key: "dtype" +// CHECK-NEXT: value { +// CHECK-NEXT: type: DT_INT32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: attr { +// CHECK-NEXT: key: "shape" +// CHECK-NEXT: value { +// CHECK-NEXT: shape { +// CHECK-NEXT: dim { +// CHECK-NEXT: size: 10 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: experimental_debug_info { +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: node { +// CHECK-NEXT: name: "main" +// CHECK-NEXT: op: "_Retval" +// CHECK-NEXT: input: "Placeholder" +// CHECK-NEXT: attr { +// CHECK-NEXT: key: "T" +// CHECK-NEXT: value { +// CHECK-NEXT: type: DT_INT32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: attr { +// CHECK-NEXT: key: "index" +// CHECK-NEXT: value { +// CHECK-NEXT: i: 0 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: library { +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir index 5f805636531..72dd164ea3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir @@ -34,7 +34,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: attr { - // CHECK-NEXT: key: "dense_shapes" + // CHECK: key: "dense_shapes" // CHECK-NEXT: value { // CHECK-NEXT: list { // CHECK-NEXT: shape { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/shape_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/shape_list_attr.mlir index c56204c1cd4..d0a8a0e47de 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/shape_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/shape_list_attr.mlir @@ -2,7 +2,7 @@ // CHECK: attr { -// CHECK-NEXT: key: "dtypes" +// CHECK: key: "dtypes" // CHECK-NEXT: value { // CHECK-NEXT: list { // CHECK-NEXT: type: DT_INT32 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir index 8f3d0b5c9ba..415c37e8fff 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple.mlir @@ -5,7 +5,7 @@ func @main() { // CHECK-NEXT: name: "Empty/shape" // CHECK-NEXT: op: "Const" // CHECK: attr { - // CHECK-NEXT: key: "dtype" + // CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir index 0ba7c90b244..725c1b9484e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir @@ -4,8 +4,8 @@ func @main() { ^bb0: // CHECK: name: "node_name" // CHECK-NEXT: op: "Const" - // CHECK: attr { - // CHECK-NEXT: key: "dtype" + // CHECK-NEXT: attr { + // CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir index 329d5e77348..463c1fd63ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir @@ -6,7 +6,7 @@ func @main() { // CHECK-NEXT: name: "Const" // CHECK-NEXT: op: "Const" // CHECK-NEXT: attr { - // CHECK-NEXT: key: "dtype" + // CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_FLOAT // CHECK-NEXT: } @@ -31,7 +31,7 @@ func @main() { // CHECK-NEXT: name: "foo" // CHECK-NEXT: op: "foo" // CHECK-NEXT: input: "Const" - // CHECK-NEXT: experimental_debug_info { + // CHECK: experimental_debug_info { // CHECK-NEXT: } // CHECK-NEXT: } %1:2 = tf_executor.island wraps "tf.foo"(%0#0) {device = ""} : (tensor) -> tensor<*xf32> loc("foo") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir index 3fa1f8001e4..5f92d789066 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir @@ -18,6 +18,11 @@ func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> { // CHECK: node { // CHECK: name: "tf.LegacyCall" // CHECK-NEXT: op: "foo0" +// CHECK: attr { +// CHECK-NEXT: key: "_output_shapes" +// CHECK-NEXT: value { +// CHECK-NEXT: list { +// CHECK-NEXT: shape { // CHECK: library { // CHECK-NEXT: function { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir index ed0b53407bc..29f7f860f1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir @@ -14,8 +14,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { // CHECK: node { // CHECK-NEXT: name: "input0" // CHECK-NEXT: op: "Placeholder" -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK-NEXT: attr { +// CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } @@ -36,8 +36,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { // CHECK-NEXT: node { // CHECK-NEXT: name: "input1" // CHECK-NEXT: op: "Placeholder" -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK-NEXT: attr { +// CHECK: key: "dtype" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } @@ -66,7 +66,7 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: experimental_debug_info { +// CHECK: experimental_debug_info { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: node { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir index 4e3564fb6a0..3c7d0e6c644 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir @@ -124,3 +124,36 @@ func @nested_ops(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: device = "c" // CHECK-NEXT: tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]] + + +// CHECK-LABEL: func @do_not_hoist_ops_with_virtual_device +// CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf32>, [[VAL_1:%.*]]: tensor<*xf32>) +func @do_not_hoist_ops_with_virtual_device(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { + %0:8 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2: i32} { + %1 = "tf.Shape"(%ri) {device = "", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor + %2 = "tf.opA"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor) -> tensor<*xi32> + %3 = "tf_device.launch"() ( { + %b = "tf.opB"(%1) : (tensor) -> tensor<*xi32> + tf_device.return %b : tensor<*xi32> + }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32> + %4 = "tf_device.launch"() ( { + %c = "tf.opC"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor) -> tensor<*xi32> + tf_device.return %c : tensor<*xi32> + }) {device = "c"} : () -> tensor<*xi32> + tf_device.return %1, %2, %3, %4 : tensor, tensor<*xi32>, tensor<*xi32>, tensor<*xi32> + } + return +} + +// CHECK: [[SHAPE:%.*]] = "tf.Shape"([[VAL_0]]) +// CHECK: tf_device.replicate({{\[}}[[VAL_0]], [[VAL_1]]] as [[VAL_4:%.*]]: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { +// CHECK: [[OP_A:%.*]] = "tf.opA"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor) -> tensor<*xi32> +// CHECK: [[LAUNCH_B:%.*]] = "tf_device.launch"() ( { +// CHECK: [[OP_B:%.*]] = "tf.opB"([[SHAPE]]) : (tensor) -> tensor<*xi32> +// CHECK: tf_device.return [[OP_B]] : tensor<*xi32> +// CHECK: }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32> +// CHECK: [[LAUNCH_C:%.*]] = "tf_device.launch"() ( { +// CHECK: [[OP_C:%.*]] = "tf.opC"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor) -> tensor<*xi32> +// CHECK: tf_device.return [[OP_C]] : tensor<*xi32> +// CHECK: }) {device = "c"} : () -> tensor<*xi32> +// CHECK: tf_device.return [[SHAPE]], [[OP_A]], [[LAUNCH_B]], [[LAUNCH_C]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 706524e39a1..99b8823f2bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -183,7 +183,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @reused_if_then_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-error @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}} func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return // CHECK-SAME: tensor<*xf32> @@ -192,7 +192,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @reused_if_else_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-error @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}} func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) @@ -278,4 +278,23 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { func @variant_body_func(%arg0: tensor>>) -> tensor>> { return %arg0 : tensor>> } + + // Test propagation from called functions to the call site. + // CHECK-LABEL: func @stateful_partitioned_call( + // CHECK-SAME: -> tensor<20xi32> + func @stateful_partitioned_call(%arg0: tensor<20xi32>) -> tensor<*xi32> { + // CHECK: tf.PartitionedCall + // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32> + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @a_called_func} : (tensor<20xi32>) -> (tensor<*xi32>) + // CHECK: tf.StatefulPartitionedCall + // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32> + %1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_partitioned_call_func} : (tensor<20xi32>) -> (tensor<*xi32>) + return %0 : tensor<*xi32> + } + func @a_called_func(%arg0: tensor) -> (tensor) { + return %arg0 : tensor + } + func @stateful_partitioned_call_func(%arg0: tensor) -> (tensor) { + return %arg0 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir index 35cfb19a80b..1a13338b0ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir @@ -187,3 +187,199 @@ func @main() { %write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor return } + +// ----- + +// Tests while loop with access to the tensor array defined outside and its +// gradient defined inside. The gradient creation should be moved outside. + +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]]) + %1:2 = "tf.While"(%ta#0, %size) { + body = @while_body, cond = @while_cond, device = "", is_stateless = false} + : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: "tf.Slice"(%[[READ]], + %read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} +// CHECK: func @while_body(%[[BARG0:.*]]: tensor>>, %[[BARG1:.*]]: tensor, %[[BARG2:.*]]: tensor>>) +func @while_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]]) + %sub = "tf.Sub"(%arg1, %const1) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]], + // CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]]) + %write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]], + // CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]]) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]] + return %arg0, %sub : tensor, tensor +} +// CHECK: func @while_cond(%[[CARG0:.*]]: tensor>>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor>>) +func @while_cond(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: return %[[CARG1]] + return %arg1 : tensor +} + +// ----- + +// Tests If op with access to the tensor array defined outside and its gradient +// defined inside. The gradient creation should be moved outside. + +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor + %cond = "tf._SomeOp"() : () -> tensor + // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]]) + %1 = "tf.If"(%cond, %ta#0) { + then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false} + : (tensor, tensor) -> tensor + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: "tf.Slice"(%[[READ]], + %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} +// CHECK: func @then_branch(%[[TARG0:.*]]: tensor>>, %[[TARG1:.*]]: tensor>>, %[[TARG2:.*]]: tensor>>) +func @then_branch(%arg0: tensor) -> tensor { + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]], + // CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]]) + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: return %[[TARG0]] + return %arg0 : tensor +} +// CHECK: func @else_branch(%[[EARG0:.*]]: tensor>>, %[[EARG1:.*]]: tensor>>, %[[EARG2:.*]]: tensor>>) +func @else_branch(%arg0: tensor) -> tensor { + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]], + // CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]]) + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: return %[[EARG0]] + return %arg0 : tensor +} + +// ----- + +// Tests (Stateful)PartitionedCall op with access to the tensor array defined +// outside and its gradient defined inside. The gradient creation should be +// moved outside. + +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor + %cond = "tf._SomeOp"() : () -> tensor + // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]]) + // CHECK-SAME: f = @callee_tensorarray_decomposed + %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor) -> tensor + // CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]]) + // CHECK-SAME: f = @callee_tensorarray_decomposed + %call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor) -> tensor + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: "tf.Slice"(%[[READ]], + %read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} +// CHECK-LABEL: func @callee +// CHECK-SAME: (%[[OCARG0:.*]]: tensor) -> tensor +func @callee(%arg0: tensor) -> tensor { + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + %grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor, tensor) -> (tensor, tensor) + %gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + return %arg0 : tensor +} +// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor>>, %[[CARG1:.*]]: tensor>>, %[[CARG2:.*]]: tensor>>) +// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor>>) -> tensor<5x3xf32> +// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]], +// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]]) +// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor>>) -> tensor<5x3xf32> +// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]], +// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]]) +// CHECK: return %[[CARG0]] + +// ----- + +// Test the pass reports failure on unknown size. + +func @main(%arg0: tensor) -> () { + // expected-error @+1 {{unknown max element count}} + %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + return +} + +// ----- + +// Test the pass reports failure on unknown shape. + +func @main(%arg0: tensor) -> () { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // expected-error @+1 {{unknown element shape}} + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + return +} + +// ----- + +// Tests that the pass reports error on ambiguous tensor array. + +func @main(%arg0: tensor) -> () { + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} + : (tensor, tensor, tensor) -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // expected-error @+1 {{unknown tensor array}} + %read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} +func @if_then(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg0 : tensor +} +func @if_else(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg1 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 115b8938975..483da1c70f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" +#include #include #include @@ -49,6 +50,9 @@ enum EinsumEquation { FourDMatrixDotProd, ThreeDReshapeTail, FourDBatchMatMul, + BroadcastMatMul, + ReduceSum, + TransposeMatMul, UnsupportedEquation }; @@ -121,6 +125,18 @@ EinsumEquation parseEquation(const std::vector& eqn) { if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) { return EinsumEquation::ThreeDReshapeTail; } + // BFH,HO->BFO + if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) { + return EinsumEquation::BroadcastMatMul; + } + // LBH,BL->BH + if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) { + return EinsumEquation::ReduceSum; + } + // LBH,BKL->BKH + if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) { + return EinsumEquation::TransposeMatMul; + } return EinsumEquation::UnsupportedEquation; } @@ -151,6 +167,28 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, perm_op); } +TF::SumOp createSumOp(Value value, Location loc, + llvm::ArrayRef redux_axes, + PatternRewriter* rewriter) { + auto value_type = value.getType().cast(); + auto shape = value_type.getShape(); + auto redux_type = RankedTensorType::get( + {static_cast(redux_axes.size())}, rewriter->getIntegerType(32)); + auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes); + auto redux_op = rewriter->create(loc, redux_type, redux_attr); + std::vector sum_shape(shape.size() - redux_axes.size()); + int count = 0; + for (int i = 0; i < shape.size(); ++i) { + if (std::find(redux_axes.begin(), redux_axes.end(), i) == + redux_axes.end()) { + sum_shape[count] = shape[i]; + count++; + } + } + auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType()); + return rewriter->create(loc, sum_type, value, redux_op); +} + TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, Type element_type, Location loc, PatternRewriter* rewriter) { @@ -173,7 +211,6 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( Value lhs = op.getOperand(0); Value rhs = op.getOperand(1); Location loc = op.getLoc(); - if (!lhs.getType().isa()) { // LHS must be a ranked tensor type return failure(); @@ -193,10 +230,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( return failure(); } - // Currently support use cases of LHS, RHS dims = 3 or 4 + // Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4} const int dims_lhs = lhs_shape.size(); const int dims_rhs = rhs_shape.size(); - if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) { + if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_lhs > 4) { return failure(); } @@ -209,6 +246,46 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( rewriter.replaceOp(op, bmm_op.getResult()); return success(); } + if (einsum_eqn == EinsumEquation::BroadcastMatMul) { + // Case "BFH,HO->BFO" + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, bmm_op.getResult()); + return success(); + } + if (einsum_eqn == EinsumEquation::ReduceSum) { + // Case "LBH,BL->BH" + // Transpose LHS + lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); + // Reshape RHS + auto rhs_element_type = rhs_type.getElementType(); + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim1 = rhs_shape[1]; + auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1}, + rhs_element_type, loc, &rewriter); + auto mul_op = rewriter.create(loc, lhs, reshaped_rhs); + + auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter); + rewriter.replaceOp(op, {sum_op.getResult()}); + return success(); + } + if (einsum_eqn == EinsumEquation::TransposeMatMul) { + // Case "LBH,BKL->BKH" + // Transpose LHS + lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); + // Transpose RHS + rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter); + std::vector bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter); + rewriter.replaceOp(op, {trans_bmm.getResult()}); + return success(); + } if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { // Case "BFD,DNH->BFNH" auto lhs_type = lhs.getType().cast(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 7d0e7e20e5d..30444b88677 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass { // Returns the defining op for a value looking through islands. static Operation* GetDefiningOp(Value val) { Operation* op = val.getDefiningOp(); - auto island_op = dyn_cast(op); + auto island_op = dyn_cast_or_null(op); if (!island_op) return op; auto yield_op = island_op.GetYield(); auto index = val.cast().getResultNumber(); @@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) { static Value LookThroughIdentityOp(Value pred_val) { if (!pred_val) return pred_val; auto op = GetDefiningOp(pred_val); - if (auto id_op = dyn_cast(op)) pred_val = id_op.input(); + if (auto id_op = dyn_cast_or_null(op)) + pred_val = id_op.input(); return pred_val; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index c3a0b1e303a..0ec30f44ce7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -40,6 +40,20 @@ class LegalizeHloToTf : public FunctionPass { void runOnFunction() override; }; +// Returns whether the two values are guaranteed to be broadcastable to the +// same shape, this broadcasts size 1 tensors up to any rank. +// TODO(jpienaar): Move this to more general location. +static bool AreBroadcastCompatible(Value x, Value y) { + auto x_ranked = x.getType().dyn_cast(); + auto y_ranked = y.getType().dyn_cast(); + if (!x_ranked || !y_ranked) { + return true; + } + SmallVector resultShape; + return OpTrait::util::getBroadcastedShape(x_ranked.getShape(), + y_ranked.getShape(), resultShape); +} + #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc" /// Performs the lowering to XLA dialect. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index bc4dd24f498..8a71005bf70 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -20,14 +20,16 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" -def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; +def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; //===----------------------------------------------------------------------===// // Binary op patterns. //===----------------------------------------------------------------------===// -class DirectBinaryPat - : Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>; +// Check that two values can be broadcasted together +// TODO(jpienaar): Move somewhere more general +def AreBroadcastCompatible : Constraint, + "types must be broadcastable">; foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], [HLO_DivOp, TF_DivOp], @@ -37,24 +39,41 @@ foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], [HLO_MulOp, TF_MulOp], [HLO_PowOp, TF_PowOp], [HLO_DivOp, TF_RealDivOp], - [HLO_SubOp, TF_SubOp]] in - def : DirectBinaryPat; + [HLO_SubOp, TF_SubOp], + [HLO_Atan2Op, TF_Atan2Op], + [HLO_RemOp, TF_ModOp]] in + def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; -def LowerRightShiftSigned : - Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), - [(SignedIntTensor $r)]>; +foreach pair = [[HLO_AndOp, TF_BitwiseAndOp], + [HLO_OrOp, TF_BitwiseOrOp], + [HLO_XorOp, TF_BitwiseXorOp]] in + def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>; +foreach pair = [[HLO_AndOp, TF_LogicalAndOp], + [HLO_OrOp, TF_LogicalOrOp]] in + def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; + +def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), + [(AreBroadcastCompatible $l, $r)]>; +def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), + [(AreBroadcastCompatible $l, $r)]>; + +def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), + [(AreBroadcastCompatible $l, $r)]>; //===----------------------------------------------------------------------===// // Unary op patterns. //===----------------------------------------------------------------------===// -foreach Mapping = [ - [HLO_AbsOp, TF_AbsOp], +foreach Mapping = [[HLO_AbsOp, TF_AbsOp], + [HLO_BitcastConvertOp, TF_BitcastOp], [HLO_CeilOp, TF_CeilOp], [HLO_CosOp, TF_CosOp], [HLO_ExpOp, TF_ExpOp], + [HLO_Expm1Op, TF_Expm1Op], [HLO_FloorOp, TF_FloorOp], [HLO_ImagOp, TF_ImagOp], [HLO_IsFiniteOp, TF_IsFiniteOp], @@ -65,8 +84,46 @@ foreach Mapping = [ [HLO_RealOp, TF_RealOp], [HLO_RsqrtOp, TF_RsqrtOp], [HLO_SinOp, TF_SinOp], + [HLO_SignOp, TF_SignOp], [HLO_SqrtOp, TF_SqrtOp], - [HLO_TanhOp, TF_TanhOp], - ] in { - def : Pat<(Mapping[0] $input), (Mapping[1] $input)>; -} + [HLO_TanhOp, TF_TanhOp]] in + def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>; + +def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>; + +def : Pat<(HLO_BroadcastOp $arg, $shape), + (TF_BroadcastToOp $arg, (TF_ConstOp $shape))>; +def : Pat<(HLO_TransposeOp $arg, $permutation), + (TF_TransposeOp $arg, (TF_ConstOp $permutation))>; +def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; + +//===----------------------------------------------------------------------===// +// Ternary op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(HLO_ClampOp $min, $arg, $max), + (TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>; +def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>; + +//===----------------------------------------------------------------------===// +// Variadic op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(HLO_ConcatenateOp $inputs, $dim), + (TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>; + +//===----------------------------------------------------------------------===// +// Compare op patterns. +//===----------------------------------------------------------------------===// + +foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], + [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in + def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + [(AreBroadcastCompatible $l, $r)]>; + +foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], + [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], + [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], + [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in + def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 4d836cd056a..5c21e1bffcc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -33,6 +33,9 @@ namespace mlir { namespace TFDevice { namespace { + +constexpr char kDeviceAttr[] = "device"; + struct ReplicateInvariantOpHoistingPass : public FunctionPass { void runOnFunction() override; @@ -109,6 +112,22 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, } } +// Check if op uses a device from a list of virtual devices. +bool UsesVirtualDevice(const Optional& virtual_devices, + Operation* operation) { + if (!virtual_devices.hasValue()) return false; + + auto result = operation->walk([&](Operation* op) { + StringAttr op_device = op->getAttrOfType(kDeviceAttr); + if (!op_device) return WalkResult::advance(); + + if (virtual_devices.getValue().get(op_device.getValue())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return result.wasInterrupted(); +} + // Checks if op and inner op operands are all replicate invariant. bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { auto ancestor_of_replicate = [&](Region* region) { @@ -140,9 +159,13 @@ void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { }); Region* replicate_region = &replicate_op.body(); + Optional virtual_device_list = replicate_op.devices(); for (Operation& inner_op : llvm::make_early_inc_range(replicate_op.GetBody())) { if (llvm::isa(inner_op)) continue; + // Skip hoisting if the inner op device attribute is a virtual device + // defined by tf_device.replicate. + if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue; if (IsOpReplicateInvariant(replicate_region, &inner_op)) inner_op.moveBefore(replicate_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 5cec4c0ed66..0b41225e503 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -163,9 +163,6 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); const int num_replicas = replicate_op.n().getLimitedValue(); - if (!replicate_op.GetBody().getOps().empty()) - return replicate_op.emitError() - << "TPU computation with multiple logical cores is not supported."; // Create islands per replica. llvm::SmallVector replicas = diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 84c527d18ed..e01055916ce 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -236,6 +236,40 @@ bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { return changed; } +// Infers the shape from a (Stateful)PartionedCall operation by looking up the +// called function and propagating the return type. +bool InferShapeForCall(Operation* op) { + auto call_op = cast(op); + CallInterfaceCallable callable = call_op.getCallableForCallee(); + SymbolRefAttr sym = callable.dyn_cast(); + if (!sym) return false; + FuncOp func = + dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); + if (!func) return false; + + bool changed = false; + // Map each of the results of the call to the returned type of the + // function. + for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) { + if (std::get<0>(result).getType() == std::get<1>(result)) continue; + // Skip already statically shaped results. + auto shaped_type = std::get<0>(result).getType().dyn_cast(); + if (!shaped_type || shaped_type.hasStaticShape()) continue; + + auto new_type = std::get<1>(result).dyn_cast(); + if (!new_type) continue; + + // Inserts a cast back to the original type if any user is not in the + // TF dialect. + AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), + op->getDialect(), shaped_type); + // Finally we inferred the shape and replace the type for this result. + std::get<0>(result).setType(new_type); + changed = true; + } + return changed; +} + } // namespace bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, @@ -264,6 +298,11 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return false; } + // Handle call operations by looking up callee and infering return shape as + // needed. + if (isa(op) || isa(op)) + return InferShapeForCall(op); + // tf.Cast are only inferred if they have at least one user in the tf dialect. // This is necessary to avoid reprocessing the tf.Cast that are inserted at // the end of this function. @@ -438,7 +477,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); if (num_uses != 1) { - func.emitError(llvm::formatv( + func.emitWarning(llvm::formatv( "expected control flow function {0} to have exactly 1 use, found {1}.", func.getName(), num_uses)); return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 59dab25c15c..b7efc5aa64b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -19,7 +19,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -55,6 +56,8 @@ namespace { namespace cutil = TF::collection_ops_util; +using std::string; + // A pass that converts tensor array operations to tensor operations and // read/assign ops on local variables. A later resource lifting pass can further // remove the local variables. @@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, return split.emitOpError("unknown or invalid split tensor shape"); } int64_t length = buffer_type.getDimSize(0) / *count; - for (auto len : lengths_const.value().getValues()) { + for (const auto& len : lengths_const.value().getValues()) { if (length == len.getSExtValue()) continue; return split.emitOpError("different split lengths are not supported"); } @@ -145,7 +148,7 @@ struct TensorArrayStats { // this is a gradient. bool accumulate_on_write; // Maps from a gradient source string to the local variable to the gradient. - llvm::SmallDenseMap grads; + llvm::StringMap grads; }; LogicalResult HandleTensorArrayV3Op( @@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op( cutil::GetElement(index_reshape, buffer, builder, write.getLoc(), /*keep_slice_shape=*/true); // Add a size-1 leading dimension to elem. - for (auto dim : buffer.getType().cast().getShape()) - LOG(ERROR) << " buffer : " << dim; auto slice_type = original_elem.getType().cast(); - for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim; elem = builder.create( write.getLoc(), ArrayRef{slice_type}, ArrayRef{elem, cutil::GetR1Const(slice_type.getShape(), builder, @@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op( return success(); } +LogicalResult CreateAndInitializeGradVariable(Type local_var_type, + Operation* op, Value* var) { + OpBuilder builder(op); + *var = builder.create( + op->getLoc(), ArrayRef{local_var_type}, ArrayRef{}, + ArrayRef{}); + Value buffer; + auto buffer_type = getElementTypeOrSelf(local_var_type) + .cast() + .getSubtypes()[0] + .cast(); + if (failed(cutil::CreateInitBufferValue( + buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op, + buffer_type.getElementType(), builder, &buffer))) { + return failure(); + } + cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc()); + return success(); +} + LogicalResult HandleTensorArrayGradV3Op( TF::TensorArrayGradV3Op grad, llvm::SmallDenseMap* stats) { @@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op( Value grad_var; auto sit = stats->find(local_var); if (sit == stats->end()) return grad.emitOpError("unknown tensor array"); - auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value()); + auto emplace_res = + sit->getSecond().grads.try_emplace(grad.source().str(), Value()); if (!emplace_res.second) { // If the source has been assigned a grad, use it. - grad_var = emplace_res.first->getSecond(); + grad_var = emplace_res.first->second; } else { - grad_var = builder.create( - grad.getLoc(), ArrayRef{local_var.getType()}, ArrayRef{}, - ArrayRef{}); - Value buffer; - auto buffer_type = getElementTypeOrSelf(local_var.getType()) - .cast() - .getSubtypes()[0] - .cast(); - if (failed(cutil::CreateInitBufferValue( - buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), - grad, buffer_type.getElementType(), builder, &buffer))) { + if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad, + &grad_var))) { return failure(); } - cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc()); - emplace_res.first->getSecond() = grad_var; + emplace_res.first->second = grad_var; // Write to a grad accumulates with previous writes. (*stats)[grad_var].accumulate_on_write = true; } @@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op( return success(); } -LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) { - llvm::SmallDenseMap stats; +// Updates func's type according to its current arguments and return values. +void UpdateFuncType(FuncOp func) { + llvm::SmallVector arg_types; + for (auto arg : func.getArguments()) arg_types.push_back(arg.getType()); + func.setType(FunctionType::get( + arg_types, + llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()), + func.getContext())); +} + +// Finds the accessed gradient sources for each tensor array argument. +llvm::SmallDenseMap> AccessedGradients( + ArrayRef funcs, ModuleOp module) { + llvm::SmallDenseMap> result; + llvm::SmallDenseMap> result_sets; + auto insert = [&](Value v, const string& source) { + auto arg = v.cast(); + if (!arg) return; + auto insert_res = result_sets[arg.getArgNumber()].insert(source); + if (!insert_res.second) return; + result[arg.getArgNumber()].push_back(source); + }; + for (FuncOp func : funcs) { + for (auto& op : func.front().getOperations()) { + if (llvm::isa(&op) || llvm::isa(&op)) { + op.replaceAllUsesWith(op.getOperands()); + continue; + } + if (auto grad = llvm::dyn_cast(&op)) { + insert(grad.handle(), grad.source().str()); + } else if (auto while_op = llvm::dyn_cast(&op)) { + auto body = module.lookupSymbol(while_op.body()); + auto cond = module.lookupSymbol(while_op.cond()); + for (const auto& entry : AccessedGradients({body, cond}, module)) { + for (const string& source : entry.getSecond()) { + insert(while_op.getOperand(entry.getFirst()), source); + } + } + } else if (auto if_op = llvm::dyn_cast(&op)) { + auto then_branch = module.lookupSymbol(if_op.then_branch()); + auto else_branch = module.lookupSymbol(if_op.else_branch()); + for (const auto& entry : + AccessedGradients({then_branch, else_branch}, module)) { + for (const string& source : entry.getSecond()) { + insert(if_op.getOperand(entry.getFirst() + 1), source); + } + } + } else if (auto pc = llvm::dyn_cast(&op)) { + if (!pc.f().isa()) continue; + auto callee = module.lookupSymbol(pc.f().getRootReference()); + for (const auto& entry : AccessedGradients({callee}, module)) { + for (const string& source : entry.getSecond()) { + insert(pc.getOperand(entry.getFirst()), source); + } + } + } else if (auto spc = + llvm::dyn_cast(&op)) { + auto callee = module.lookupSymbol(spc.f()); + for (const auto& entry : AccessedGradients({callee}, module)) { + for (const string& source : entry.getSecond()) { + insert(spc.getOperand(entry.getFirst()), source); + } + } + } + } + } + return result; +} + +// Contains cached information for decomposed callee functions for (stateful) +// partitioned call ops. +struct PartitionedCallTensorArrayOpsInfo { + bool signature_change; + FuncOp decomposed_callee; + llvm::SmallVector>, 4> + arg_grads; + llvm::SmallVector, 4> ret_forward_input; +}; + +// Updates a called function's input signature by adjusting resource types, and +// adding required gradient arguments. +void ChangeFunctionInputSignature( + FuncOp func, + const llvm::SmallDenseMap>& grads, + llvm::function_ref ta_arg_buffer_type, + llvm::function_ref ta_accumulate_on_write, + llvm::SmallDenseMap* stats) { + int64_t original_args = func.getNumArguments(); + for (int64_t argnum = 0; argnum < original_args; ++argnum) { + auto arg = func.getArgument(argnum); + Type t = ta_arg_buffer_type(argnum); + if (!t) continue; + arg.setType(t); + auto grad_it = grads.find(argnum); + if (grad_it == grads.end()) continue; + llvm::StringMap grads_map; + for (const string& source : grad_it->getSecond()) { + auto g = func.front().addArgument(t); + (*stats)[g].accumulate_on_write = true; + grads_map[source] = g; + } + auto& stat = (*stats)[arg]; + stat.accumulate_on_write = ta_accumulate_on_write(argnum); + stat.grads = std::move(grads_map); + } + UpdateFuncType(func); +} + +LogicalResult DecomposeTensorArrayOps( + Block*, ModuleOp, llvm::SmallDenseMap*, + llvm::SmallDenseMap*); + +LogicalResult HandleWhileOp( + TF::WhileOp while_op, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + auto body = module.lookupSymbol(while_op.body()); + auto cond = module.lookupSymbol(while_op.cond()); + auto grads = AccessedGradients({body, cond}, module); + auto ta_arg_buffer_type = [&](int64_t index) -> Type { + auto it = stats->find(while_op.getOperand(index)); + if (it == stats->end()) return nullptr; + return it->getFirst().getType(); + }; + auto ta_accumulate_on_write = [&](int64_t index) { + auto it = stats->find(while_op.getOperand(index)); + if (it == stats->end()) return false; + return it->getSecond().accumulate_on_write; + }; + llvm::SmallDenseMap body_stats; + ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type, + ta_accumulate_on_write, &body_stats); + llvm::SmallDenseMap cond_stats; + ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type, + ta_accumulate_on_write, &cond_stats); + if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats, + decomposed_partitioned_call_callees)) || + failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats, + decomposed_partitioned_call_callees))) { + return failure(); + } + if (body_stats.empty() && cond_stats.empty()) return success(); + auto old_body_ret = body.front().getTerminator(); + auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands()); + for (int64_t i = 0; i < while_op.getNumResults(); ++i) { + if (!ta_arg_buffer_type(i)) continue; + auto retval = old_body_ret->getOperand(i); + auto arg = retval.dyn_cast(); + if (!arg) { + return while_op.emitOpError( + "output tensor array does not alias input in a while loop"); + } + for (const string& source : grads[i]) { + new_retvals.push_back(body_stats[arg].grads[source]); + } + } + OpBuilder(old_body_ret).create(old_body_ret->getLoc(), new_retvals); + old_body_ret->erase(); + UpdateFuncType(body); + // Recreate the while op. + auto operands = llvm::to_vector<8>(while_op.getOperands()); + for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { + auto grad_it = grads.find(i); + auto& stat = (*stats)[operands[i]]; + if (grad_it == grads.end()) continue; + for (const string& source : grad_it->getSecond()) { + auto it = stat.grads.find(source); + if (it != stat.grads.end()) { + operands.push_back(it->second); + } else { + Value grad_var; + if (failed(CreateAndInitializeGradVariable(operands[i].getType(), + while_op, &grad_var))) { + return failure(); + } + stat.grads[source] = grad_var; + operands.push_back(grad_var); + } + } + } + OpBuilder builder(while_op); + auto new_while = + builder.create(while_op.getLoc(), body.getType().getInputs(), + operands, while_op.getAttrs()); + // Clear the output shapes as it is not needed for XLA lowering. + new_while.setAttr("output_shapes", builder.getArrayAttr({})); + for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { + if (ta_arg_buffer_type(i)) { + while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i)); + } else { + while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i)); + } + } + while_op.erase(); + return success(); +} + +LogicalResult HandleIfOp( + TF::IfOp if_op, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + auto then_branch = module.lookupSymbol(if_op.then_branch()); + auto else_branch = module.lookupSymbol(if_op.else_branch()); + auto grads = AccessedGradients({then_branch, else_branch}, module); + auto ta_arg_buffer_type = [&](int64_t index) -> Type { + auto it = stats->find(if_op.getOperand(index + 1)); + if (it == stats->end()) return nullptr; + return it->getFirst().getType(); + }; + auto ta_accumulate_on_write = [&](int64_t index) { + auto it = stats->find(if_op.getOperand(index + 1)); + if (it == stats->end()) return false; + return it->getSecond().accumulate_on_write; + }; + llvm::SmallDenseMap then_stats; + ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type, + ta_accumulate_on_write, &then_stats); + llvm::SmallDenseMap else_stats; + ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type, + ta_accumulate_on_write, &else_stats); + if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats, + decomposed_partitioned_call_callees)) || + failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats, + decomposed_partitioned_call_callees))) { + return failure(); + } + if (then_stats.empty() && else_stats.empty()) return success(); + // Recreate the if op. + auto operands = llvm::to_vector<8>(if_op.getOperands()); + for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) { + auto grad_it = grads.find(i); + auto& stat = (*stats)[operands[i + 1]]; + if (grad_it == grads.end()) continue; + for (const string& source : grad_it->getSecond()) { + auto it = stat.grads.find(source); + if (it != stat.grads.end()) { + operands.push_back(it->second); + } else { + Value grad_var; + if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(), + if_op, &grad_var))) { + return failure(); + } + stat.grads[source] = grad_var; + operands.push_back(grad_var); + } + } + } + OpBuilder builder(if_op); + auto new_if = builder.create(if_op.getLoc(), + then_branch.getType().getResults(), + operands, if_op.getAttrs()); + // Clear the output shapes as it is not needed for XLA lowering. + new_if.setAttr("output_shapes", builder.getArrayAttr({})); + auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t { + auto retval = f.front().getTerminator()->getOperand(ret_ind); + auto arg = retval.dyn_cast(); + if (!arg) return -1; + return arg.getArgNumber(); + }; + for (int64_t i = 0; i < if_op.getNumResults(); ++i) { + if (!getElementTypeOrSelf(if_op.getResult(i).getType()) + .isa()) { + if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i)); + continue; + } + int64_t then_forward_input = ret_forwards_input(then_branch, i); + int64_t else_foward_input = ret_forwards_input(else_branch, i); + if (then_forward_input != else_foward_input || then_forward_input < 0) { + return if_op.emitOpError( + "branches do not forward the same input resource"); + } + if_op.getResult(i).replaceAllUsesWith( + if_op.getOperand(then_forward_input + 1)); + } + if_op.erase(); + return success(); +} + +template +LogicalResult HandlePartitionedCallOp( + CallOp call, FuncOp callee, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + auto emplace_res = decomposed_partitioned_call_callees->try_emplace( + callee, PartitionedCallTensorArrayOpsInfo()); + auto& info = emplace_res.first->getSecond(); + // Recreates the call op with info. + auto recreate_caller = [&]() -> LogicalResult { + auto new_operands = llvm::to_vector<8>(call.getOperands()); + for (const auto& entry : info.arg_grads) { + auto it = stats->find(call.getOperand(entry.first)); + if (it == stats->end()) return call.emitOpError("unknown tensor array"); + for (const string& source : entry.second) { + auto grad_it = it->getSecond().grads.find(source); + if (grad_it != it->getSecond().grads.end()) { + new_operands.push_back(grad_it->second); + } else { + Value grad_var; + if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(), + call, &grad_var))) { + return failure(); + } + it->getSecond().grads[source] = grad_var; + new_operands.push_back(grad_var); + } + } + } + OpBuilder builder(call); + auto new_call = builder.create( + call.getLoc(), info.decomposed_callee.getType().getResults(), + new_operands, call.getAttrs()); + new_call.setAttr( + "f", builder.getSymbolRefAttr( + const_cast(info.decomposed_callee).getName())); + for (const auto& entry : info.ret_forward_input) { + call.getResult(entry.first) + .replaceAllUsesWith(call.getOperand(entry.second)); + } + call.replaceAllUsesWith(new_call); + call.erase(); + return success(); + }; + if (!emplace_res.second) { + // This callee was handled before. + if (!info.signature_change) return success(); + return recreate_caller(); + } + // Rewrite the callee on a cloned function. + info.signature_change = false; + auto ta_arg_buffer_type = [&](int64_t index) -> Type { + auto it = stats->find(call.getOperand(index)); + if (it == stats->end()) return nullptr; + info.signature_change = true; + return it->getFirst().getType(); + }; + auto ta_accumulate_on_write = [&](int64_t index) { + auto it = stats->find(call.getOperand(index)); + if (it == stats->end()) return false; + return it->getSecond().accumulate_on_write; + }; + auto callee_clone = callee.clone(); + auto grads = AccessedGradients({callee_clone}, module); + for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) { + auto it = grads.find(i); + if (it == grads.end()) continue; + info.arg_grads.emplace_back(i, it->getSecond()); + } + llvm::SmallDenseMap callee_stats; + ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type, + ta_accumulate_on_write, &callee_stats); + if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module, + &callee_stats, + decomposed_partitioned_call_callees))) { + return failure(); + } + for (int64_t i = 0; i < call.getNumResults(); ++i) { + auto ret = callee_clone.front().getTerminator()->getOperand(i); + if (!getElementTypeOrSelf(ret.getType()).isa()) continue; + auto arg = ret.dyn_cast(); + if (!arg) continue; + info.ret_forward_input.emplace_back(i, arg.getArgNumber()); + } + + if (!info.signature_change) { + // Signature is not modified. We do not need to keep two copies. + info.signature_change = false; + auto name = callee.getName(); + callee.erase(); + callee_clone.setName(name); + SymbolTable(module).insert(callee_clone); + } else { + info.decomposed_callee = callee_clone; + // Add the clone with a new name. + auto name = + llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed") + .str(); + callee_clone.setName(name); + SymbolTable(module).insert(callee_clone); + } + if (info.signature_change) return recreate_caller(); + return success(); +} + +LogicalResult DecomposeTensorArrayOps( + Block* block, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { if (llvm::isa(&op) || llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto ta = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayV3Op(ta, module, &stats))) { + if (failed(HandleTensorArrayV3Op(ta, module, stats))) { return failure(); } } else if (auto read = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure(); + if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure(); } else if (auto write = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure(); + if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure(); } else if (auto concat = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure(); + if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure(); } else if (auto split = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure(); + if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure(); } else if (auto size = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure(); + if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure(); } else if (auto grad = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure(); + if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure(); } else if (auto gather = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure(); + if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure(); } else if (auto scatter = llvm::dyn_cast(&op)) { - if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) { + if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) { return failure(); } } else if (auto close = llvm::dyn_cast(&op)) { close.erase(); + } else if (auto while_op = llvm::dyn_cast(&op)) { + if (failed(HandleWhileOp(while_op, module, stats, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto if_op = llvm::dyn_cast(&op)) { + if (failed(HandleIfOp(if_op, module, stats, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto pcall = llvm::dyn_cast(&op)) { + if (!pcall.f().isa()) { + return pcall.emitOpError( + "TensorArray decomposition does not support call with nested " + "references."); + } + if (failed(HandlePartitionedCallOp( + pcall, module.lookupSymbol(pcall.f().getRootReference()), + module, stats, decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto spcall = + llvm::dyn_cast(&op)) { + if (failed(HandlePartitionedCallOp( + spcall, module.lookupSymbol(spcall.f()), module, stats, + decomposed_partitioned_call_callees))) { + return failure(); + } } } return success(); @@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() { auto module = getModule(); auto main = module.lookupSymbol("main"); if (!main) return; - if (failed(DecomposeTensorArrayOps(&main.front(), module))) { + llvm::SmallDenseMap stats; + llvm::SmallDenseMap + decomposed_partitioned_call_callees; + if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats, + &decomposed_partitioned_call_callees))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 114a03cc45d..a47fe2232c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -194,6 +194,25 @@ StatusOr> ConvertTFDialectOpToNodeDef( "When populating derived attrs for ", inst->getName().getStringRef().str()); } + + // If the instruction is in the TF dialect, the code above already filtered + // results with control types. Here we only add the shapes for the leading + // values with ShapedType, assuming values with non-ShapedType are put at the + // end of the result. + if (!ignore_unregistered_attrs && inst->getNumResults() > 0) { + auto values = inst->getResults(); + auto begin = values.begin(); + auto end = values.begin(); + while (end != values.end() && (*end).getType().isa()) + end++; + if (begin != end) { + mlir::TF::ResultShapeRange output_shapes = { + mlir::TF::ResultShapeIterator(begin), + mlir::TF::ResultShapeIterator(end)}; + TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", output_shapes, + node_def->mutable_attr())); + } + } return node_def; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index e6657ebc8dd..a19ad1f2940 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -26,8 +26,11 @@ namespace tensorflow { // Converts an MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted to. If the // `ignore_unregistered_attrs` argument is set to true, the attributes which are -// not in the op registry will be ignored. Set it to true if the returned -// NodeDef will be executed by the linked TF Eager runtime. +// not in the op registry will be ignored. If the `ignore_unregistered_attrs` +// argument is not set to true, _output_shapes attribute is added to nodes with +// ShapedType for the leading values with ShapedType in the results of the +// nodes. Set it to true if the returned NodeDef will be executed by the linked +// TF Eager runtime. stream_executor::port::StatusOr> ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name, bool ignore_unregistered_attrs); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 26c2f89d3ad..155995a4f65 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -53,6 +53,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -110,8 +111,17 @@ using stream_executor::port::StatusOr; namespace { -const char* disable_call_shape_inference_attribute_name = - "_disable_call_shape_inference"; +bool IsDisableCallShapeInferenceAttribute(const AttrValue& attr_value, + llvm::StringRef attr_name) { + return attr_name.compare("_disable_call_shape_inference") == 0 && + attr_value.value_case() == AttrValue::kB; +} + +bool IsOutputShapesAttribute(const AttrValue& attr_value, + llvm::StringRef attr_name) { + return attr_name.compare("_output_shapes") == 0 && + attr_value.value_case() == AttrValue::kList; +} // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings @@ -1451,9 +1461,7 @@ mlir::Operation* ImporterBase::createOperation( for (const auto& name_and_value : node.attrs()) { const auto& attr_name = name_and_value.first; const AttrValue& attr_value = name_and_value.second; - if (strcmp(attr_name.c_str(), - disable_call_shape_inference_attribute_name) == 0 && - attr_value.value_case() == AttrValue::kB) { + if (IsDisableCallShapeInferenceAttribute(attr_value, attr_name)) { disable_call_shape_inference = attr_value.b(); } } @@ -1596,15 +1604,25 @@ Status ImporterBase::ConvertNode(const Node& node) { using FuncPairType = std::pair; std::vector funcs; result.attributes.reserve(node.attrs().size() + 2); + auto abstract_op = result.name.getAbstractOperation(); + auto derived_op = + abstract_op + ? abstract_op->getInterface() + : nullptr; for (const auto& name_and_value : node.attrs()) { const auto& attr_name = name_and_value.first; + // Skip adding derived attributes to the generated op. + if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue; const AttrValue& attr_value = name_and_value.second; - // LegacyCall can only represent _diable_call_shape_inference attribute. - // If a call has other attributes, can't convert it to LegacyCall. + + // Remove _output_shapes attribute that will be added by the exporter. + if (IsOutputShapesAttribute(attr_value, attr_name)) continue; + + // We represent the _diable_call_shape_inference attribute and remove + // the _output_shapes attribute for LegacyCall. If a call has other + // attributes, we can't convert it to LegacyCall. if (convert_to_legacy_call && - (strcmp(attr_name.c_str(), - disable_call_shape_inference_attribute_name) || - attr_value.value_case() != AttrValue::kB)) { + !IsDisableCallShapeInferenceAttribute(attr_value, attr_name)) { convert_to_legacy_call = false; } if (attr_value.value_case() == AttrValue::kFunc) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 3fd711b9ef8..3e250ec287b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -43,6 +44,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -74,7 +77,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string, Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef arg_shapes, bool use_tuple_args, - const xla::CustomShapeRepresentationFn shape_representation_fn, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, std::vector* xla_input_shapes) { xla_input_shapes->clear(); @@ -93,7 +96,24 @@ Status GetXlaInputShapes( DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype)); TF_ASSIGN_OR_RETURN(xla_shape, - shape_representation_fn(arg_shapes[i], dtype)); + shape_representation_fn(arg_shapes[i], dtype, + /*use_fast_memory=*/false)); + + // Rewrite layout with sharding, if sharding is set. + auto sharding = + main_func.getArgAttrOfType(i, "xla_hlo.sharding"); + if (!sharding) continue; + + absl::optional arg_sharding; + xla::OpSharding op_sharding; + if (!op_sharding.ParseFromString(sharding.getValue().str())) + return errors::InvalidArgument("failed to parse argument sharding ", i, + " '", sharding.getValue().str(), "'"); + + TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding)); + TF_RETURN_IF_ERROR( + RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, + shape_representation_fn, &xla_shape)); } if (use_tuple_args) { xla_input_shapes->push_back( @@ -108,9 +128,14 @@ Status GetXlaInputShapes( // output based on static shapes in MLIR module Status GetOutputInfo( mlir::ModuleOp module, - const xla::CustomShapeRepresentationFn shape_representation_fn, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, xla::Shape* xla_output_shape, std::vector* outputs) { + auto shape_representation_fn_no_fast_memory = + [shape_representation_fn](const TensorShape& shape, DataType dtype) { + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + }; + mlir::FuncOp main_func = module.lookupSymbol("main"); mlir::FunctionType func_type = main_func.getType(); @@ -121,8 +146,9 @@ Status GetOutputInfo( shapes.reserve(func_type.getNumResults()); for (mlir::Type type : func_type.getResults()) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, - TypeToShape(type, shape_representation_fn)); + TF_ASSIGN_OR_RETURN( + xla::Shape shape, + xla::TypeToShape(type, shape_representation_fn_no_fast_memory)); auto tensor_type = type.dyn_cast(); shapes.push_back(shape); @@ -225,16 +251,17 @@ static void RegisterDialects() { (void)init_once; } -} // namespace -// namespace +} // namespace -Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, - xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple) { +Status ConvertMLIRToXlaComputation( + mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, + bool use_tuple_args, bool return_tuple, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); + tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); // LegalizeTFControlFlow encapsulates arguments for control flow operations @@ -273,7 +300,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, xla::HloProto hlo_proto; TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto, - use_tuple_args, return_tuple)); + use_tuple_args, return_tuple, + shape_representation_fn)); *xla_computation = xla::XlaComputation(hlo_proto.hlo_module()); return Status::OK(); } @@ -281,7 +309,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -292,35 +320,28 @@ static Status CompileMlirToXlaHlo( if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op); + if (!shape_representation_fn) + shape_representation_fn = IdentityShapeRepresentationFn(); + // Convert MLIR module to XLA HLO proto contained in XlaComputation. compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, compilation_result->computation.get(), use_tuple_args, - /*return_tuple=*/true)); + /*return_tuple=*/true, shape_representation_fn)); // Construct mapping from XlaComputation's arg to input edges of execute // node. GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); - auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, - DataType dtype) -> StatusOr { - if (shape_representation_fn) - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; - // Compute all input shapes. TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, - shape_representation_fn_no_fast_memory, + shape_representation_fn, &compilation_result->xla_input_shapes)); // Compute all output descriptions. - TF_RETURN_IF_ERROR(GetOutputInfo( - module_op, shape_representation_fn_no_fast_memory, - &compilation_result->xla_output_shape, &compilation_result->outputs)); + TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn, + &compilation_result->xla_output_shape, + &compilation_result->outputs)); // Compute what resource variables need to be updated after XlaComputation's // execution. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 0dd4b8c5efe..2ce0a31eb78 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -43,9 +43,13 @@ namespace tensorflow { // entry computation. // return_tuple: when this is true, always create a tuple result for the // entry computation. -Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, - xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple); +// shape_representation_fn: when this is set, this shape representation function +// will be used to determine argument and result shapes. Otherwise the +// original shape will be used as is. +Status ConvertMLIRToXlaComputation( + mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, + bool use_tuple_args, bool return_tuple, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index f65fcc1016d..d406934c520 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -40,7 +40,8 @@ xla::StatusOr TestShapeRepresentation(const TensorShape& shape, } TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { - string invalid_mlir_module = "totally @invalid MLIR module {here} <-"; + constexpr char invalid_mlir_module[] = + "totally @invalid MLIR module {here} <-"; std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; @@ -49,7 +50,7 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); EXPECT_EQ(s.ToString(), - "Invalid argument: could not parse MLIR module: error: " + "Invalid argument: could not parse MLIR module-:1:1: error: " "custom op 'totally' is unknown\n"); } @@ -76,7 +77,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.6 + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { %arg_tuple.1 = (f32[], f32[]) parameter(0) @@ -134,7 +135,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.5 + constexpr char expected_hlo_module_string[] = R"(HloModule main.5 ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) @@ -181,7 +182,7 @@ ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { // "tf.Shape" can only be folded away after shape inference. tf.Reshape can // only be lowered when tf.Shape is folded into a constant. - string mlir_module = R"( + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> { %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> @@ -205,7 +206,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.6 + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} @@ -221,7 +222,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { } TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { - string mlir_module = R"( + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor @@ -245,13 +246,14 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_signature = + constexpr char expected_signature[] = R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))"; EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), ::testing::HasSubstr(expected_signature)); } -constexpr llvm::StringRef kBroadcastGradientArgsModule = R"( +TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main() -> (tensor<0xi32>, tensor<0xi32>) { %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> @@ -261,12 +263,11 @@ module attributes {tf.versions = {producer = 179 : i32}} { } )"; -TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { std::vector arg_shapes(2, TensorShape()); XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBroadcastGradientArgsModule, arg_shapes, + mlir_module, arg_shapes, /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -275,7 +276,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.4 + constexpr char expected_hlo_module_string[] = R"(HloModule main.4 ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { %arg_tuple.1 = () parameter(0) @@ -288,6 +289,128 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { status_or_hlo_module.ValueOrDie()->ToString()); } +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED +TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + return + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10}), + TensorShape({10, 1024}), + TensorShape({128, 1024})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + TF_ASSERT_OK(s); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + TF_ASSERT_OK(status_or_hlo_module.status()); + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 + +ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { + %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} + %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 + %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 + %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 + ROOT %tuple.5 = () tuple() +} + +)"; + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + +TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) { + return + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + ASSERT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "failed to parse argument sharding 0 'bad_sharding'"); +} + +TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { + func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10}), + TensorShape({10, 1024}), + TensorShape({128, 1024})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + TF_ASSERT_OK(s); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + TF_ASSERT_OK(status_or_hlo_module.status()); + constexpr char expected_hlo_module_string[] = R"(HloModule main.9 + +ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { + %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0) + %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 + %reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2) + %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 + %reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3) + %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 + %reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4) + ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} +} + +)"; + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + // Verify that conversion from Graph to MLIR and empty shape representation // function is successful. TEST(CompileGraphToXlaHlo, Basic) { @@ -311,7 +434,7 @@ TEST(CompileGraphToXlaHlo, Basic) { result.computation->proto(), module_config); ASSERT_TRUE(status_or_hlo_module.ok()); - string expected_hlo_module_string = R"(HloModule main.3 + constexpr char expected_hlo_module_string[] = R"(HloModule main.3 ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index fbb775e061f..6bdd4838e97 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -55,6 +55,12 @@ static llvm::cl::opt verify_passes( llvm::cl::desc("Run the verifier after each transformation pass"), llvm::cl::init(true)); +// NOLINTNEXTLINE +static llvm::cl::opt allowUnregisteredDialects( + "allow-unregistered-dialect", + llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::init(true)); + int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); @@ -77,7 +83,7 @@ int main(int argc, char **argv) { if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline, split_input_file, verify_diagnostics, - verify_passes))) + verify_passes, allowUnregisteredDialects))) return 1; output->keep(); return 0; diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 6597eeaa967..453576ba9ee 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -618,7 +618,10 @@ cc_library( ":hlo", ":type_to_shape", ":xla_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -629,6 +632,8 @@ cc_library( "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis", @@ -727,6 +732,7 @@ genrule( outs = ["operator_writers.inc"], cmd = ("$(location :operator_writer_gen) " + "-I external/llvm-project/mlir/include " + + "-I external/org_tensorflow " + "$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " + " -o $@"), tools = [":operator_writer_gen"], diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td index 6a60a42861a..48b765f2299 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td @@ -51,7 +51,7 @@ class HLOClient_Op traits> : // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // shape broadcasting. // -// These have 1:1 correspondance with same-named ops in the xla_hlo dialect; +// These have 1:1 correspondence with same-named ops in the xla_hlo dialect; // however, those operations do not support broadcasting. // // See: diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index abaad272acd..86e865a1657 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -177,6 +177,31 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { result.addAttribute("value", value); } +//===----------------------------------------------------------------------===// +// DotGeneralOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DotGeneralOp op) { + auto dot_dimension_numbers = op.dot_dimension_numbers(); + int64_t lhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_batching_dimensions().getValues()); + int64_t rhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_batching_dimensions().getValues()); + if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) { + return op.emitError() + << "lhs and rhs should have the same number of batching dimensions"; + } + int64_t lhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_contracting_dimensions().getValues()); + int64_t rhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_contracting_dimensions().getValues()); + if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) { + return op.emitError() << "lhs and rhs should have the same number of " + "contracting dimensions"; + } + return success(); +} + //===----------------------------------------------------------------------===// // IotaOp //===----------------------------------------------------------------------===// @@ -598,6 +623,28 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) { return success(); } +// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary +// BroadcastInDimOp. +class DynamicBroadcastInDimOpNotActuallyDynamic + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto type = op.getType().dyn_cast(); + if (!type || !type.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "requires static shape"); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.operand(), op.broadcast_dimensions()); + return success(); + } +}; + +void DynamicBroadcastInDimOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ClampOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index bc05a1c100c..00b43198c55 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -802,6 +802,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", let results = (outs HLO_Tensor); + let hasCanonicalizer = 1; // Cannot be exported to legacy formats. let hasCustomHLOConverter = 1; } @@ -948,6 +949,7 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral ); let results = (outs HLO_Tensor); + let verifier = [{ return Verify(*this); }]; } // Define Base Einsum op within the HLO dialect as these are client ops and diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 670f34b4318..8922cc131c6 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MemoryBuffer.h" @@ -37,8 +38,12 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -49,6 +54,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" using ::stream_executor::port::StatusOr; @@ -64,6 +71,7 @@ using ::tensorflow::uint8; constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map"; constexpr char kShapeIndicesAttr[] = "shape_indices"; constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices"; +constexpr char kShardingAttr[] = "xla_hlo.sharding"; constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas"; // Passes through everything except for unique_ptr, on which it calls get(). @@ -377,7 +385,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( // returns absl::nullopt. static absl::optional CreateOpShardingFromAttribute( mlir::Operation* op) { - auto sharding = op->getAttrOfType("xla_hlo.sharding"); + auto sharding = op->getAttrOfType(kShardingAttr); if (!sharding) { return absl::nullopt; } @@ -389,6 +397,43 @@ static absl::optional CreateOpShardingFromAttribute( return sharding_proto; } +// Checks if all shardings are set. +static bool AllOptionalShardingsAreSet( + llvm::ArrayRef> shardings) { + return llvm::all_of(shardings, + [](const absl::optional& sharding) { + return sharding.has_value(); + }); +} + +// Extracts sharding from attribute string. +static absl::optional CreateOpShardingFromStringRef( + llvm::StringRef sharding) { + xla::OpSharding sharding_proto; + if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt; + return sharding_proto; +} + +// Extracts argument and result shardings from function. +static void ExtractShardingsFromFunction( + mlir::FuncOp function, + llvm::SmallVectorImpl>* arg_shardings, + llvm::SmallVectorImpl>* ret_shardings) { + arg_shardings->resize(function.getNumArguments(), + absl::optional()); + for (int i = 0; i < function.getNumArguments(); ++i) + if (auto sharding = + function.getArgAttrOfType(i, kShardingAttr)) + (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); + + ret_shardings->resize(function.getNumResults(), + absl::optional()); + for (int i = 0; i < function.getNumResults(); ++i) + if (auto sharding = + function.getResultAttrOfType(i, kShardingAttr)) + (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); +} + namespace mlir { namespace { class ConvertToHloModule { @@ -402,12 +447,17 @@ class ConvertToHloModule { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. - explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args, - bool return_tuple) + explicit ConvertToHloModule( + mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, + tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), - return_tuple_(return_tuple) {} + return_tuple_(return_tuple), + shape_representation_fn_(shape_representation_fn) { + if (!shape_representation_fn_) + shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn(); + } // Perform the lowering to XLA. This function returns failure if an error was // encountered. @@ -432,6 +482,8 @@ class ConvertToHloModule { LogicalResult LowerBasicBlockAsFunction( Block* block, xla::XlaBuilder* builder, bool is_entry_function, const std::vector& entry_args_same_across_replicas, + llvm::ArrayRef> arg_shardings, + llvm::ArrayRef> ret_shardings, xla::XlaComputation* result); ::xla::HloModuleProto ConsumeMainProto() { @@ -445,10 +497,22 @@ class ConvertToHloModule { ConvertToHloModule::ValueLoweringMap* value_lowering); private: - LogicalResult Lower(mlir::Operation* inst, bool is_entry_function, - xla::XlaBuilder* builder, - ConvertToHloModule::ValueLoweringMap* value_lowering, - xla::XlaComputation* result); + LogicalResult Lower( + mlir::Operation* inst, bool is_entry_function, + llvm::ArrayRef> ret_shardings, + xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering, + xla::XlaComputation* result); + + LogicalResult SetEntryTupleShapesAndLeafReplication( + Block* block, const std::vector& entry_args_same_across_replicas, + llvm::SmallVectorImpl* arg_shapes, + std::vector* leaf_replication); + + LogicalResult SetEntryTupleShardings( + Block* block, xla::XlaBuilder* builder, + llvm::ArrayRef> arg_shardings, + llvm::SmallVectorImpl* arg_shapes); // The module being lowered. mlir::ModuleOp module_; @@ -465,6 +529,10 @@ class ConvertToHloModule { // Whether to always return a tuple. bool return_tuple_; + // Shape representation function to determine entry function argument and + // result shapes. + tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + // Unique suffix to give to the name of the next lowered region. size_t region_id_ = 0; }; @@ -876,7 +944,9 @@ StatusOr CreateLiteralFromAttr(Type type, ElementsAttr attr) { } LogicalResult ConvertToHloModule::Lower( - mlir::Operation* inst, bool is_entry_function, xla::XlaBuilder* builder, + mlir::Operation* inst, bool is_entry_function, + llvm::ArrayRef> ret_shardings, + xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaComputation* result) { if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) { @@ -906,11 +976,37 @@ LogicalResult ConvertToHloModule::Lower( xla::XlaOp return_value; unsigned num_return_values = inst->getNumOperands(); if ((return_tuple_ && is_entry_function) || num_return_values > 1) { + const bool has_ret_shardings = + !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings); + std::vector returns(num_return_values); - for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) { - returns[i] = value_map[inst->getOperand(i)]; + for (OpOperand& ret : inst->getOpOperands()) { + unsigned index = ret.getOperandNumber(); + returns[index] = value_map[ret.get()]; + if (!is_entry_function || !has_ret_shardings) continue; + + xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); + StatusOr reshape = + tensorflow::ReshapeWithCorrectRepresentationAndSharding( + builder, returns[index], return_shape, shape_representation_fn_, + ret_shardings[index], /*fast_mem=*/false); + if (!reshape.ok()) + return inst->emitError() << reshape.status().error_message(); + + returns[index] = reshape.ValueOrDie(); } + + if (has_ret_shardings) { + xla::OpSharding sharding; + sharding.set_type(xla::OpSharding::TUPLE); + for (auto& ret_sharding : ret_shardings) + *sharding.add_tuple_shardings() = ret_sharding.value(); + + builder->SetSharding(sharding); + } + return_value = xla::Tuple(builder, returns); + builder->ClearSharding(); } else if (num_return_values == 1) { return_value = value_map[inst->getOperand(0)]; } @@ -976,6 +1072,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { xla::XlaComputation computation; std::vector entry_args_same_across_replicas; + llvm::SmallVector, 4> arg_shardings; + llvm::SmallVector, 4> ret_shardings; if (entry_function) { bool any_arg_replicated = false; entry_args_same_across_replicas.reserve(f.getNumArguments()); @@ -1000,21 +1098,90 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { // means no replication. This avoids the need for unrelated tests to handle // this field. if (!any_arg_replicated) entry_args_same_across_replicas.clear(); + + ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); } - if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function, - entry_args_same_across_replicas, - &computation))) { + if (failed(LowerBasicBlockAsFunction( + &f.front(), &builder, entry_function, entry_args_same_across_replicas, + arg_shardings, ret_shardings, &computation))) { return failure(); } lowered_computation_[f] = std::move(computation); return success(); } +LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication( + Block* block, const std::vector& entry_args_same_across_replicas, + llvm::SmallVectorImpl* arg_shapes, + std::vector* leaf_replication) { + arg_shapes->reserve(block->getNumArguments()); + leaf_replication->reserve(block->getNumArguments()); + for (BlockArgument& arg : block->getArguments()) { + arg_shapes->push_back(xla::TypeToShape(arg.getType())); + xla::Shape& arg_shape = arg_shapes->back(); + tensorflow::TensorShape arg_tensor_shape; + auto status = + tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + tensorflow::DataType dtype; + status = tensorflow::ConvertToDataType(arg.getType(), &dtype); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype, + /*use_fast_memory=*/false); + if (!arg_shape_status.ok()) + return block->getParentOp()->emitError() + << arg_shape_status.status().error_message(); + + arg_shape = std::move(arg_shape_status.ValueOrDie()); + + if (entry_args_same_across_replicas.empty()) continue; + for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i) + leaf_replication->push_back( + entry_args_same_across_replicas[arg.getArgNumber()]); + } + + return success(); +} + +LogicalResult ConvertToHloModule::SetEntryTupleShardings( + Block* block, xla::XlaBuilder* builder, + llvm::ArrayRef> arg_shardings, + llvm::SmallVectorImpl* arg_shapes) { + if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { + xla::OpSharding sharding; + sharding.set_type(xla::OpSharding::TUPLE); + for (auto arg_sharding : llvm::enumerate(arg_shardings)) { + auto hlo_sharding = + xla::HloSharding::FromProto(arg_sharding.value().value()); + if (!hlo_sharding.ok()) + return block->getParentOp()->emitError() + << hlo_sharding.status().error_message(); + + auto status = tensorflow::RewriteLayoutWithShardedShape( + hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false, + shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + *sharding.add_tuple_shardings() = arg_sharding.value().value(); + } + + builder->SetSharding(sharding); + } + + return success(); +} + LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( Block* block, xla::XlaBuilder* builder, bool is_entry_function, const std::vector& entry_args_same_across_replicas, + llvm::ArrayRef> arg_shardings, + llvm::ArrayRef> ret_shardings, xla::XlaComputation* result) { - auto& bb = *block; // Mapping from the Value to lowered XlaOp. The code below lowers in // program order and will fail if an operand is unseen. This can be improved. ValueLoweringMap lowering; @@ -1022,29 +1189,28 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // If using tuples as input, then there is only one input parameter that is a // tuple. if (is_entry_function && use_tuple_args_) { - std::vector arg_shapes; - arg_shapes.reserve(bb.getNumArguments()); + llvm::SmallVector arg_shapes; std::vector leaf_replication; - for (auto& arg : bb.getArguments()) { - arg_shapes.push_back(xla::TypeToShape(arg.getType())); - if (!entry_args_same_across_replicas.empty()) { - for (int i = 0; i < xla::ShapeUtil::GetLeafCount(arg_shapes.back()); - ++i) { - leaf_replication.push_back( - entry_args_same_across_replicas[arg.getArgNumber()]); - } - } - } + if (failed(SetEntryTupleShapesAndLeafReplication( + block, entry_args_same_across_replicas, &arg_shapes, + &leaf_replication))) + return failure(); + + if (failed( + SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes))) + return failure(); + xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication); - for (auto& it : llvm::enumerate(bb.getArguments())) { - lowering[it.value()] = xla::GetTupleElement(tuple, it.index()); - } + + builder->ClearSharding(); + + for (BlockArgument& arg : block->getArguments()) + lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber()); } else { - for (auto& it : llvm::enumerate(bb.getArguments())) { - auto arg = it.value(); - auto num = it.index(); + for (BlockArgument& arg : block->getArguments()) { + auto num = arg.getArgNumber(); xla::Shape shape = xla::TypeToShape(arg.getType()); if (entry_args_same_across_replicas.empty()) { lowering[arg] = @@ -1058,8 +1224,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } } - for (auto& inst : bb) - if (failed(Lower(&inst, is_entry_function, builder, &lowering, result))) + for (auto& inst : *block) + if (failed(Lower(&inst, is_entry_function, ret_shardings, builder, + &lowering, result))) return failure(); return success(); @@ -1069,8 +1236,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation( mlir::Region* region, xla::XlaComputation* func) { std::unique_ptr builder = module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++)); - return LowerBasicBlockAsFunction(®ion->front(), builder.get(), - /*is_entry_function=*/false, {}, func); + return LowerBasicBlockAsFunction( + ®ion->front(), builder.get(), + /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{}, + /*arg_shardings=*/{}, /*ret_shardings=*/{}, func); } std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) { @@ -1241,9 +1410,12 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, } // namespace Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple) { + bool use_tuple_args, bool return_tuple, + const tensorflow::XlaCompiler::ShapeRepresentationFn + shape_representation_fn) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - ConvertToHloModule converter(module, use_tuple_args, return_tuple); + ConvertToHloModule converter(module, use_tuple_args, return_tuple, + shape_representation_fn); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 983d61a8af2..1a341b00d0c 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,7 +32,9 @@ namespace mlir { // Multiple return values are always converted to a tuple and returned as a // single value. Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple); + bool use_tuple_args, bool return_tuple, + const tensorflow::XlaCompiler::ShapeRepresentationFn + shape_representation_fn = nullptr); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 18a29968600..1b7d879ca03 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -31,6 +31,14 @@ func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: return %1 : tensor<1x4xi32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic +func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { + // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: return %[[RESULT]] : tensor<5x4xf32> + return %0 : tensor<5x4xf32> +} + // CHECK-LABEL: @complex_expand_fold func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index d8a1a156b0c..a462a7f4a1f 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -795,7 +795,8 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> - // CHECK: "xla_hlo.broadcast_in_dim" + // CHECK: [[CST:%.+]] = xla_hlo.constant + // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CST]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> return %0 : tensor<16x16x16x16xf32> @@ -2837,7 +2838,9 @@ func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { // CHECK-SAME: [[START:%.*]]: tensor, [[STOP:%.*]]: tensor func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = "xla_hlo.convert"([[NUM]]) + // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM]]) + // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} @@ -2856,6 +2859,15 @@ func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor } +// CHECK-LABEL: func @linspace_invalid_num +func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: xla_hlo.constant {value = dense<[]> : tensor<0xi32>} : tensor + // CHECK: "tf.LinSpace" + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor + %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor + return %1 : tensor +} + //===----------------------------------------------------------------------===// // Conv op legalizations. //===----------------------------------------------------------------------===// @@ -3705,8 +3717,8 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-LABEL: func @batchmatmulv2_broadcast_singleton_dimension func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, @@ -3720,8 +3732,8 @@ func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg // CHECK-LABEL: func @batchmatmulv2_lhs_batch func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, @@ -3735,8 +3747,8 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) // CHECK-LABEL: func @batchmatmulv2_rhs_batch func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, @@ -3757,8 +3769,8 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) // CHECK-LABEL: func @batchmatmulv2_adj_real func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>) -> tensor<5x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32> + // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>, {{.*}}) -> tensor<5x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<2x4xf32> // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, @@ -3780,8 +3792,8 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2 // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex> - // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex>) -> tensor<5x2xcomplex> - // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex>) -> tensor<2x4xcomplex> + // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex>, {{.*}}) -> tensor<5x2xcomplex> + // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex>, {{.*}}) -> tensor<2x4xcomplex> // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index 3317d24d820..1e375e142f7 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -125,3 +125,77 @@ func @dynamic_reduce(%arg: memref, // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] // CHECK: loop.yield + +// ----- + +func @reduce_window(%arg: memref<112x112xf32>, + %init: memref, + %result: memref<56x56xf32>) { + "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.maximum"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }) { + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + window_dimensions = dense<[3, 3]> : tensor<2xi64>, + window_strides = dense<[2, 2]> : tensor<2xi64> + } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () + return +} +// CHECK-LABEL: func @reduce_window( +// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) { +// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant 1 : i1 +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C2:%.*]] = constant 2 : index +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C56:%.*]] = constant 56 : index +// CHECK-DAG: [[C112:%.*]] = constant 112 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref +// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel +// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) +// CHECK-SAME: init ([[INIT]]) -> f32 { + +// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index +// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index +// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index +// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]] +// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]] + +// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index +// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index +// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index +// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] +// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]] + +// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[OPERAND_ELEM:%.*]] = +// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] +// CHECK: loop.yield [[OPERAND_ELEM]] : f32 +// CHECK: } else { +// CHECK: loop.yield [[INIT]] : f32 +// CHECK: } + +// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: loop.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] +// CHECK: loop.yield +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 037eded9ba6..a1cddab54c9 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -932,3 +932,29 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { %0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<[]> : tensor<0xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +} + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<[]> : tensor<0xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 65704ca8dec..f8ec40aa42d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -932,6 +932,28 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { } }; +class ConvertBroadcastToOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::BroadcastToOp op, + PatternRewriter &rewriter) const override { + auto input_type = op.input().getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure(op, "requires ranked shape"); + } + auto rank_diff = output_type.getRank() - input_type.getRank(); + // The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting. + auto broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(rank_diff, output_type.getRank())); + rewriter.replaceOpWithNewOp( + op, output_type, op.input(), op.shape(), + rewriter.getI64TensorAttr(broadcast_dimensions)); + return success(); + } +}; + // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp // depending on arity of the op. class ConvertEinsumOp : public OpRewritePattern { @@ -2110,13 +2132,23 @@ class ConvertLinSpaceOp : public OpRewritePattern { return failure(); } + DenseIntElementsAttr num_attr; + if (!matchPattern(op.num(), m_Constant(&num_attr))) { + return rewriter.notifyMatchFailure(op, "Num must be a constant scalar"); + } + + if (num_attr.begin() == num_attr.end()) { + return rewriter.notifyMatchFailure(op, "Num must not be empty"); + } + int64_t num = (*num_attr.begin()).getSExtValue(); + // Calculate the scaling that needs to be applied to the iota. auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( op.getLoc(), op.num(), result_type.getElementType()); - if (op.num() > 1) { + if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); step_denominator = rewriter.create( @@ -3734,15 +3766,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBF16FloorDivOp, ConvertConv2D, ConvertConv2DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp, - ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, - ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, - ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp, - ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp, - ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, - ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, - ConvertSizeOp, ConvertSoftmaxOp, + ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D, + ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp, + ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp, + ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, + ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, + ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp, + ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp, + ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp, + ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index b9599201601..2f825a882f7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -186,16 +186,6 @@ def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), (HLO_AddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; -//===----------------------------------------------------------------------===// -// BroadcastTo op patterns. -//===----------------------------------------------------------------------===// - -// input and result needs to ranked for computation of the broadcast dimensions. -def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape), - (HLO_BroadcastInDimOp $input, - (BinBroadcastDimensionsNonEmpty $input, $result)), - [(AnyRankedTensor $result)]>; - //===----------------------------------------------------------------------===// // Logical & bitwise binary op patterns. //===----------------------------------------------------------------------===// @@ -382,7 +372,7 @@ class createIotaOp: NativeCodeCall< def createConvertOp: NativeCodeCall< "CreateConvertOp(&($_builder), $0.getOwner()->getLoc(), $1, $2)">; -// Performs a substitution of MatrixBandPartOp for XLA HLO ops. Psuedocode is +// Performs a substitution of MatrixBandPartOp for XLA HLO ops. Pseudocode is // shown below, given a tensor `input` with k dimensions [I, J, K, ..., M, N] // and two integers, `num_lower` and `num_upper`: // @@ -454,14 +444,14 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), // TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will // require HLO canonicalization of min and max on a tensor to ClampOp. -// TODO(hinsu): Lower unsinged and quantized types after supporting +// TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, (BinBroadcastDimensions $zero, $input)), [(TF_SintOrFpTensor $input)]>; -// TODO(hinsu): Lower unsinged and quantized types after supporting +// TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_Relu6Op AnyRankedTensor:$input), (HLO_ClampOp (HLO_ConstOp (GetScalarOfType<0> $input)), $input, diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index f2ae7227a23..1250db08ee5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -29,6 +29,50 @@ namespace mlir { namespace xla_lhlo { namespace { +// Converts a block with LHLO ops and with signature: +// ^bb(%lhs: memref, %rhs: memref, %res: memref): +// into a reduction operator of loop.reduce by doing buffer allocation for +// scalar arguments and the result of `loop.reduce` to make it compatible with +// LHLO ops. +void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, + Block* lhlo_block, + ConversionPatternRewriter* rewriter) { + Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); + rewriter->setInsertionPointToStart(&loop_reduce_op_body); + + // Allocate buffers to hold arguments of reduction operator block to stay + // compatible with the LHLO dialect ops in the reduction body. + Value elem_arg = lhlo_block->getArgument(0); + Value elem_buf = + rewriter->create(loc, elem_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body.getArgument(0), elem_buf); + Value acc_arg = lhlo_block->getArgument(1); + Value acc_buf = + rewriter->create(loc, acc_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body.getArgument(1), acc_buf); + + // Clone the ops from `xla_lhlo.reduce` into reduction operator block. + BlockAndValueMapping mapping; + mapping.map(lhlo_block->getArguments(), + ValueRange{elem_buf, acc_buf, acc_buf}); + for (auto& nested : lhlo_block->without_terminator()) { + auto clone = rewriter->clone(nested, mapping); + mapping.map(nested.getResults(), clone->getResults()); + } + Value acc_result = rewriter->create(loc, acc_buf); + rewriter->create(loc, acc_result); +} + +// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to +// extract dimension at runtime. +Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, + size_t dim_index, int64_t dim, + ConversionPatternRewriter* rewriter) { + return dim == ShapedType::kDynamicSize + ? rewriter->create(loc, shaped_value, dim_index).getResult() + : rewriter->create(loc, dim); +} + // Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` @@ -42,7 +86,7 @@ namespace { // } ) {dimensions = dense<[1]> : tensor<1xi64>} // : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () // -// is converted into: +// is roughly converted into: // // %init = load %init_buf[] : memref // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { @@ -67,15 +111,15 @@ class ReduceOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. if (xla_reduce_op.out().size() != 1) return failure(); loop::ReduceOp reduce_op = - CreateParallelLoopsWithReduceOp(xla_reduce_op, args, &rewriter); - ConvertReductionOperator(xla_reduce_op, - &reduce_op.reductionOperator().front(), &rewriter); + CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); + ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, + &xla_reduce_op.body().front(), &rewriter); rewriter.replaceOp(xla_reduce_op, llvm::None); return success(); } @@ -100,8 +144,8 @@ class ReduceOpConverter : public OpConversionPattern { // } : f32 // loop.yield // } - loop::ReduceOp CreateParallelLoopsWithReduceOp( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + loop::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); DenseSet reducing_dims; @@ -114,20 +158,13 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector parallel_lower, parallel_upper, parallel_step; SmallVector reduce_lower, reduce_upper, reduce_step; auto operand_shape = operand.getType().cast().getShape(); - Type index_type = rewriter->getIndexType(); for (auto dim : llvm::enumerate(operand_shape)) { const bool is_reducing_dim = reducing_dims.count(dim.index()); - Value ub = - dim.value() == ShapedType::kDynamicSize - ? rewriter->create(loc, operand, dim.index()).getResult() - : rewriter->create( - loc, index_type, - rewriter->getIntegerAttr(index_type, dim.value())); - Value lb = rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 0)); - Value step = rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 1)); + Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(), + rewriter); + Value lb = rewriter->create(loc, 0); + Value step = rewriter->create(loc, 1); (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb); (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub); (is_reducing_dim ? reduce_step : parallel_step).push_back(step); @@ -153,8 +190,7 @@ class ReduceOpConverter : public OpConversionPattern { out_indices.push_back(iv); } } else { - out_indices.push_back(rewriter->create( - loc, index_type, rewriter->getIntegerAttr(index_type, 0))); + out_indices.push_back(rewriter->create(loc, 0)); } rewriter->create(loc, reduction_result, out, out_indices); @@ -175,39 +211,209 @@ class ReduceOpConverter : public OpConversionPattern { loc, *xla_reduce_op.operands().begin(), indices); return rewriter->create(loc, elem); } +}; - // Converts `xla_lhlo.reduce` reduction operator into `loop.reduce` op by - // doing buffer allocation for scalar arguments and the result of - // `loop.reduce` to make it compatible with LHLO ops. - void ConvertReductionOperator(xla_lhlo::ReduceOp xla_reduce_op, - Block* loop_reduce_op_body, - ConversionPatternRewriter* rewriter) const { - rewriter->setInsertionPointToStart(loop_reduce_op_body); +// Pseudocode: +// for each index O in output +// accumulator = neutral_value +// in_bounds = true +// for each index W in window +// for each dimension i from 0 to rank - 1 +// index = O[i] * stride[i] + W[i] - pad_low[i] +// in_bounds = inbounds && (index `ult` shape[i]) +// I[i] = index +// if (in_bounds) +// value = input[I] +// else +// value = neutral_value +// accumulator = reduction_operator(output[O], value) +// output[O] = accumulator +// +// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a +// loop::ReduceOp. +// The outper `ParallelOp` refers to the parallel loops that traverese output +// buffer. The inner `ParalleOp` refers to the reduction loops that traverse +// reduction windows and `ReduceOp` contains the reduction operator. +// +// Example: +// +// func @reduce_window(%arg: memref<112x112xf32>, +// %init: memref, +// %result: memref<56x56xf32>) { +// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// ^bb0(%lhs: memref, %rhs: memref, %res: memref): +// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// : (memref, memref, memref) -> () +// "xla_lhlo.terminator"() : () -> () +// }) { +// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, +// window_dimensions = dense<[3, 3]> : tensor<2xi64>, +// window_strides = dense<[2, 2]> : tensor<2xi64> +// } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () +// return +// } +// +// is roughly converted into: +// +// %neutral_elem = load %init_buf[] : memref +// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = loop.parallel (%iw, %jw) = (%c0, %c0) +// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { +// %in_bounds = +// %elem = load %operand[%computed_i, %computed_j] +// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 +// loop.reduce(%elem_to_reduce) : f32 { +// ^bb0(%arg7: f32, %arg8: f32): +// +// } +// loop.yield +// } +// store %result, %output_buffer[%i, %j] : memref<56x56xf32> +// loop.yield +// } +// return +// } +class ReduceWindowOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; - // Allocate buffers to hold arguments of reduction operator block to stay - // compatible with the LHLO dialect ops in the reduction body. - auto loc = xla_reduce_op.getLoc(); - Value elem_arg = xla_reduce_op.body().front().getArgument(0); - Value elem_buf = - rewriter->create(loc, elem_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body->getArgument(0), - elem_buf); - Value acc_arg = xla_reduce_op.body().front().getArgument(1); - Value acc_buf = - rewriter->create(loc, acc_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body->getArgument(1), - acc_buf); + LogicalResult matchAndRewrite( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + loop::ParallelOp output_loop, window_loop; + std::tie(output_loop, window_loop) = + CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, + &rewriter); - // Clone the ops from `xla_lhlo.reduce` into reduction operator block. - BlockAndValueMapping mapping; - mapping.map(xla_reduce_op.body().front().getArguments(), - ValueRange{elem_buf, acc_buf, acc_buf}); - for (auto& nested : xla_reduce_op.body().front().without_terminator()) { - auto clone = rewriter->clone(nested, mapping); - mapping.map(nested.getResults(), clone->getResults()); + loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + xla_reduce_window_op, output_loop, window_loop, &rewriter); + + ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, + &xla_reduce_window_op.body().front(), &rewriter); + rewriter.replaceOp(xla_reduce_window_op, llvm::None); + return success(); + } + + private: + std::pair + CreateParallelLoopsToTraverseOutputAndWindow( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + ConversionPatternRewriter* rewriter) const { + auto loc = xla_reduce_window_op.getLoc(); + Value init_value = + rewriter->create(loc, xla_reduce_window_op.init_value()); + + Value zero = rewriter->create(loc, 0); + Value one = rewriter->create(loc, 1); + + // Create an outer parallel loop that spans the output of ReduceWindowOp. + Value xla_output = xla_reduce_window_op.out(); + auto output_shape = xla_output.getType().cast().getShape(); + SmallVector parallel_lower, parallel_upper, parallel_step; + for (auto dim : llvm::enumerate(output_shape)) { + parallel_upper.push_back(GetStaticOrDynamicDim( + loc, xla_output, dim.index(), dim.value(), rewriter)); + parallel_lower.push_back(zero); + parallel_step.push_back(one); } - Value acc_result = rewriter->create(loc, acc_buf); - rewriter->create(loc, acc_result); + auto output_loop = rewriter->create( + loc, parallel_lower, parallel_upper, parallel_step); + + // Create a nested loop that traverses the window. + rewriter->setInsertionPointToStart(output_loop.getBody()); + SmallVector window_lower, window_upper, window_step; + for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { + window_step.push_back(one); + window_lower.push_back(zero); + window_upper.push_back( + rewriter->create(loc, window_dim.getSExtValue())); + } + auto window_loop = rewriter->create( + loc, window_lower, window_upper, window_step, init_value); + + Value reduction_result = *window_loop.getResults().begin(); + auto output_ivs = output_loop.getInductionVars(); + rewriter->create( + loc, reduction_result, xla_output, + llvm::makeArrayRef(output_ivs.begin(), output_ivs.end())); + return std::make_pair(output_loop, window_loop); + } + + loop::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + loop::ParallelOp output_loop, loop::ParallelOp window_loop, + ConversionPatternRewriter* rewriter) const { + rewriter->setInsertionPointToStart(window_loop.getBody()); + auto loc = xla_reduce_window_op.getLoc(); + + if (!xla_reduce_window_op.window_strides().hasValue()) { + xla_reduce_window_op.emitOpError("No window strides specified."); + } + if (!xla_reduce_window_op.padding().hasValue()) { + xla_reduce_window_op.emitOpError("No padding specified."); + } + if (xla_reduce_window_op.base_dilations().hasValue() || + xla_reduce_window_op.window_dilations().hasValue()) { + xla_reduce_window_op.emitRemark( + "Lowering to parallel loops does not support `base_dilations` or " + "`window_dilations` attributes yet. The attributes will be ignored."); + } + + Value xla_operand = xla_reduce_window_op.operand(); + auto xla_operand_type = xla_operand.getType().cast(); + auto xla_operand_shape = xla_operand_type.getShape(); + + auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars()); + auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars()); + auto window_strides = xla_reduce_window_op.window_strides().getValue(); + auto padding = xla_reduce_window_op.padding().getValue(); + + SmallVector operand_indices; + // `in_bounds` is false when the element in the reduce window is in the + // padding area, true otherwise. + Value in_bounds = rewriter->create( + loc, rewriter->getI1Type(), + rewriter->getIntegerAttr(rewriter->getI1Type(), 1)); + for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) { + auto stride = window_strides.getValue(i); + auto pad_low = padding.getValue({i, 0}); + + Value stride_val = + rewriter->create(loc, stride.getSExtValue()); + Value pad_low_val = + rewriter->create(loc, pad_low.getSExtValue()); + + Value center = rewriter->create(loc, output_ivs[i], stride_val); + Value offset = rewriter->create(loc, window_ivs[i], pad_low_val); + Value index = rewriter->create(loc, center, offset); + operand_indices.push_back(index); + Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i, + xla_operand_shape[i], rewriter); + // We must check whether 0 <= index_i < shape_i, as otherwise we are in + // the pad and then we have to use the neutral element for reduction. + // Equivalently, it can be computed as the unsigned comparison index_i < + // shape_i, since a negative value wraps to a large positive value. + in_bounds = rewriter->create( + loc, in_bounds, + rewriter->create(loc, CmpIPredicate::ult, index, + upper_bound)); + } + + auto elem_or_init = + rewriter->create(loc, xla_operand_type.getElementType(), + in_bounds, /*withElseRegion=*/true); + + OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); + Value elem = then_builder.create( + loc, xla_reduce_window_op.operand(), operand_indices); + then_builder.create(loc, elem); + + OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); + else_builder.create(loc, *window_loop.initVals().begin()); + + return rewriter->create(loc, + *elem_or_init.results().begin()); } }; @@ -217,12 +423,14 @@ struct LhloLegalizeToParallelLoops auto func = getFunction(); OwningRewritePatternList patterns; - patterns.insert(func.getContext()); + patterns.insert( + func.getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index e4ec28978ba..b89472b8085 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -978,6 +979,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1009,12 +1011,13 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): scores: scores_np } indices_output, num_valid_output = sess.run([indices, num_valid], inputs) - invalid_index = len(boxes_data[0]) - 1 + invalid_index = 0 self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index], [0, 1, 3, 5, invalid_index, invalid_index]], indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1048,6 +1051,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1078,6 +1082,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1107,6 +1112,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1140,6 +1146,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1171,10 +1178,11 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): scores: scores_np } indices_output, num_valid_output = sess.run([indices, num_valid], inputs) - invalid_index = len(boxes_data[0]) - 1 + invalid_index = 0 self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1205,10 +1213,13 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): scores: scores_np } indices_output, num_valid_output = sess.run([indices, num_valid], inputs) - self.assertAllEqual([[2, 1, 3, 5, 0, 0], [1, 5, 0, 3, 3, 3]], + invalid_index = 0 + self.assertAllEqual([[2, 1, 3, 5, 0, invalid_index], + [1, 5, 0, 3, invalid_index, invalid_index]], indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1240,11 +1251,84 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): scores: scores_np } indices_output, num_valid_output = sess.run([indices, num_valid], inputs) - invalid_index = len(boxes_data[0]) - 1 + invalid_index = 0 self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index], [0, 1, 3, 5, invalid_index, invalid_index]], indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) + def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): + boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], + [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], + [[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6], + [0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]] + scores_data = [[0.9, 0.7, 0.6, 0.4, 0.3, 0.2], + [0.8, 0.7, 0.6, 0.4, 0.3, 0.1]] + max_output_size = 3 + iou_threshold = 0.5 + boxes_np = np.array(boxes_data, dtype=np.float32) + scores_np = np.array(scores_data, dtype=np.float32) + + with self.session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + with self.test_scope(): + (indices, num_valid) = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=0.5, + pad_to_max_output_size=True, + sorted_input=True, + canonicalized_coordinates=False) + + inputs = { + boxes: boxes_np, + scores: scores_np + } + indices_output, num_valid_output = sess.run([indices, num_valid], inputs) + invalid_index = 0 + self.assertAllEqual([3, 2], num_valid_output) + self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + + @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) + def testBatchedNMSFrom6DynamicInput(self): + boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], + [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], + [[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6], + [0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]] + scores_data = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3], + [0.8, 0.7, 0.6, 0.5, 0.4, 0.3]] + max_output_size = 6 + iou_threshold = 0.5 + boxes_np = np.array(boxes_data, dtype=np.float32) + scores_np = np.array(scores_data, dtype=np.float32) + + with self.session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype) + scores = array_ops.placeholder(scores_np.dtype) + + with self.test_scope(): + (indices, num_valid) = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + pad_to_max_output_size=True, + sorted_input=True, + canonicalized_coordinates=True) + + inputs = { + boxes: boxes_np, + scores: scores_np + } + indices_output, num_valid_output = sess.run([indices, num_valid], inputs) + invalid_index = 0 + self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index], + [0, 1, 3, 5, invalid_index, invalid_index]], + indices_output) + self.assertAllEqual([5, 4], num_valid_output) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 14b062e5cba..f9d792806b0 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.kernel_tests.random import util as \ random_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import stateless_random_ops as stateless +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -132,5 +133,38 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3) +class StatelessRandomOpsBenchmark(test.Benchmark): + """Microbenchmarks for the stateless random ops.""" + + def _benchmarkUniform(self, name, dtype, use_xla_jit): + + def BuilderFn(): + shape = (10, 1000, 1000) + seed_var = variables.Variable((312, 456), + dtype=dtypes.int32, + name='input') + random_t = stateless.stateless_random_uniform( + shape, seed=seed_var, dtype=dtype) + return '%s.shape%s' % (name, shape), [random_t] + + xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + + def benchmarkUniformF32(self): + self._benchmarkUniform( + 'uniform_f32', dtype=dtypes.float32, use_xla_jit=False) + + def benchmarkUniformF64(self): + self._benchmarkUniform( + 'uniform_f64', dtype=dtypes.float64, use_xla_jit=False) + + def benchmarkUniformF32XLA(self): + self._benchmarkUniform( + 'uniform_f32', dtype=dtypes.float32, use_xla_jit=True) + + def benchmarkUniformF64XLA(self): + self._benchmarkUniform( + 'uniform_f64', dtype=dtypes.float64, use_xla_jit=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 10a67e835b1..1c5867a1312 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -612,10 +612,12 @@ tf_cc_test( ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 48513a43fb3..c90261303f5 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -145,6 +145,21 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, TF_RETURN_IF_ERROR( GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies)); return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime); + } else if (node.op() == "PartitionedCall" || + node.op() == "StatefulPartitionedCall") { + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody)); + int num_inputs = fbody->fdef.signature().input_arg_size(); + std::vector compile_time_const_arg_indices(num_inputs); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fbody->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + for (int i = 0; i < num_inputs; i++) { + if (compile_time_const_arg_indices[i]) { + const_input_idxs->push_back(i); + } + } + return Status::OK(); } else if (op_def != nullptr) { return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def, const_input_idxs); @@ -166,11 +181,21 @@ Status GetCompileTimeConstInputs(const Node* node, // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. -Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_arg_indices, - std::vector* compile_time_const_nodes, - FunctionLibraryRuntime* flib_runtime, - std::function edge_filter) { +Status BackwardsConstAnalysis( + const Graph& g, std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + FunctionLibraryRuntime* flib_runtime, + std::function edge_filter_input) { + if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() && + !edge_filter_input) { + VLOG(5) << "Using cached argument indices on graph " << &g; + *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value(); + return Status::OK(); + } + auto edge_filter = [&](const Edge& e) { + return edge_filter_input ? edge_filter_input(e) : true; + }; + std::vector compile_time_const_nodes_impl; if (compile_time_const_nodes) { CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); @@ -252,6 +277,10 @@ Status BackwardsConstAnalysis(const Graph& g, // acyclic graph. DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{}, [](const Edge& edge) { return !edge.src()->IsNextIteration(); }); + if (compile_time_const_arg_indices && !edge_filter_input) { + VLOG(5) << "Setting the cache on the graph: " << &g; + g.GetConstArgIndicesCache() = *compile_time_const_arg_indices; + } return status; } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 587347ff8a5..ba5fa45fd9a 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -33,14 +33,13 @@ namespace tensorflow { // The ids of the nodes in `graph` that must be constant are returned in // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. // -// Only propagate const-ness along edges for which `edge_filter` returns true. +// If `edge_filter` is non-null, only propagate const-ness along edges for which +// `edge_filter` returns true. Status BackwardsConstAnalysis( const Graph& g, std::vector* compile_time_const_arg_indices, std::vector* compile_time_const_nodes, FunctionLibraryRuntime* flib_runtime, - std::function edge_filter = [](const Edge& e) { - return true; - }); + std::function edge_filter_input = nullptr); // Given an op kernel and function library runtime, return all the indices of // inputs that need to be compile time constant. diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index ed5f004550f..936b74f7b33 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -19,11 +19,14 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -89,6 +92,59 @@ TEST(ConstAnalysisTest, TopologicalOrder) { } } +void TestFunctionCall(bool is_stateful_partitioned_call) { + FunctionDef callee = FunctionDefHelper::Define( + "Callee", {"t:float", "shape:int32"}, {"result:float"}, {}, + {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}}); + + FunctionDefLibrary flib; + *flib.add_function() = callee; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + Scope root = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0); + auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1); + + NameAttrList call_attrs; + call_attrs.set_name("Callee"); + if (is_stateful_partitioned_call) { + ops::StatefulPartitionedCall b(root.WithOpName("Call"), + {Output(arg0), Output(arg1)}, {DT_FLOAT}, + call_attrs); + } else { + ops::PartitionedCall b(root.WithOpName("Call"), + {Output(arg0), Output(arg1)}, {DT_FLOAT}, + call_attrs); + } + + Graph graph(&flib_def); + TF_ASSERT_OK(root.ToGraph(&graph)); + + OptimizerOptions opts; + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime(nullptr, Env::Default(), + /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, &flib_def, opts)); + FunctionLibraryRuntime* lib_runtime = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr, + lib_runtime)); + + EXPECT_EQ(const_args, std::vector({false, true})); +} + +TEST(ConstAnalysisTest, PartitionedCall) { + TestFunctionCall(/*is_stateful_partitioned_call=*/false); +} + +TEST(ConstAnalysisTest, StatefulPartitionedCall) { + TestFunctionCall(/*is_stateful_partitioned_call=*/true); +} + TEST(ConstAnalysisTest, DontFollowControlDependencies) { Scope root = Scope::NewRootScope(); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 888ccff7856..e9cd5d2744e 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -107,6 +107,11 @@ class ConvBackpropInputOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); xla::Shape input_shape = TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2, + errors::InvalidArgument( + "The rank of the specified input shape must be " + "num_spatial_dims + 2. Expected ", + attrs_.num_spatial_dims + 2, " got ", input_shape.rank())); xla::StatusOr in_backprop = MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 1be651da470..17d0b87edda 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" @@ -58,18 +59,21 @@ class SliceOp : public XlaOpKernel { std::vector begin; std::vector size; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size)); + std::vector wrapped_size(size.size()); if (ctx->ConstantInputAsIntVector(1, &begin).ok()) { // `begin` is a compile-time constant. for (int i = 0; i < input_dims; ++i) { if (size[i] == -1) { // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". - size[i] = input_shape.dim_size(i) - begin[i]; + wrapped_size[i] = input_shape.dim_size(i) - begin[i]; + } else { + wrapped_size[i] = size[i]; } } for (int i = 0; i < input_dims; ++i) { int64 b = begin[i]; - int64 s = size[i]; + int64 s = wrapped_size[i]; if (input_shape.dim_size(i) == 0) { OP_REQUIRES(ctx, b == 0 && s == 0, errors::InvalidArgument( @@ -91,10 +95,28 @@ class SliceOp : public XlaOpKernel { std::vector limits; limits.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { - limits.push_back(begin[i] + size[i]); + limits.push_back(begin[i] + wrapped_size[i]); } std::vector strides(begin.size(), 1); - ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides)); + auto slice = xla::Slice(ctx->Input(0), begin, limits, strides); + // Check for slice on dynamic dimensions. + ctx->set_dynamic_dimension_is_minus_one(true); + std::vector dynamic_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size)); + + for (int64 i = 0; i < size.size(); ++i) { + if (dynamic_size[i] == -1) { + if (size[i] != -1) { + // If there is a dynamic dimension, properly set dimension size of + // the slice. + auto dynamic_size = + xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {}); + + slice = xla::SetDimensionSize(slice, dynamic_size, i); + } + } + } + ctx->SetOutput(0, slice); } else { // `begin` is not a compile-time constant. for (int i = 0; i < input_dims; ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 05f1ee1797a..9093175af75 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { @@ -115,6 +116,72 @@ class StridedSliceOp : public XlaOpKernel { slice = xla::Rev(slice, dimensions_to_reverse); } slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); + auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); + OP_REQUIRES_OK(ctx, operand_shape_or.status()); + xla::Shape xla_shape = operand_shape_or.ValueOrDie(); + if (xla_shape.is_static()) { + // Static output shape, return a static slice. + slice = xla::Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); + return; + } + auto input_dim_sizes = input_shape.dim_sizes(); + + for (int64 i = 0; i < xla_shape.rank(); ++i) { + if (xla_shape.is_dynamic_dimension(i)) { + input_dim_sizes[i] = -1; + } + } + PartialTensorShape input_partial_shape(input_dim_sizes); + partial_final_shape.Clear(); + end.clear(); + strides.clear(); + begin.clear(); + // Run shape inferenference again with partial shape. + OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, + input_partial_shape, begin_mask_, end_mask_, + ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &dummy_processing_shape, &partial_final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides)); + if (partial_final_shape.AsTensorShape(&final_shape)) { + // Static output shape, return a static slice. + slice = xla::Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); + return; + } + + // We consider slicing a dynamic tensor t with negative indices as a + // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n + for (int64 i = 0; i < partial_final_shape.dims(); ++i) { + bool dynamic_dim = partial_final_shape.dim_size(i) - 1; + bool backward_slice = end[i] < 0; + if (dynamic_dim && backward_slice) { + OP_REQUIRES( + ctx, strides[i] == 1, + errors::InvalidArgument("XLA has not implemented dynamic " + "sized slice with non-trival stride yet. " + "Please file a bug against XLA")); + + OP_REQUIRES(ctx, begin[i] >= 0, + errors::InvalidArgument( + "XLA has not implemented dynamic " + "sized slice with negative begin index %lld. " + "Please file a bug against XLA", + begin[i])); + // If there is a dynamic dimension, properly set dimension size of + // the result. + auto operand_size = xla::GetDimensionSize(ctx->Input(0), i); + + operand_size = xla::Add( + operand_size, xla::ConstantR0(ctx->builder(), end[i])); + slice = xla::SetDimensionSize( + slice, + xla::Sub(operand_size, + xla::ConstantR0(ctx->builder(), begin[i])), + i); + } + } } else { // When output shape is fully defined, it must be a size one slice: // diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9b17ebe0260..85f2d5c1fc6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } -// There is a shape_representation_fn or sharding for an output, this function -// uses a reshape to fix the layout. -xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional sharding, bool fast_mem) { - if (original_shape.IsTuple()) { - std::vector elements; - for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { - auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; - TF_ASSIGN_OR_RETURN(auto element, - ReshapeWithCorrectRepresentationAndSharding( - builder, xla::GetTupleElement(original, i), - original_shape.tuple_shapes(i), - shape_representation_fn, subsharding, fast_mem)); - elements.push_back(element); - } - return xla::Tuple(builder, elements); - } - if (!original_shape.IsArray()) return original; - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - original_shape.element_type())); - TF_ASSIGN_OR_RETURN(auto to_shape, - shape_representation_fn(shape, dtype, fast_mem)); - if (sharding) { - TF_ASSIGN_OR_RETURN(auto hlo_sharding, - xla::HloSharding::FromProto(*sharding)); - TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( - hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); - } - if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64 i = 0; i < original_shape.rank(); ++i) { - to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); - } - } - return xla::Reshape(to_shape, original); -} - // Builds the XLA computation. // - `args` is the list of input arguments // - `retvals` is the list of retvals produced by _Retval operators, in index @@ -562,13 +522,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. if (!options_.shape_representation_fn) { - options_.shape_representation_fn = - [](const TensorShape& shape, DataType dtype, - bool use_fast_memory) -> xla::StatusOr { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; + options_.shape_representation_fn = IdentityShapeRepresentationFn(); } } @@ -1502,6 +1456,15 @@ xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { return iter->second; } +XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() { + return [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; +} + // Rewrites the layout of xla_shape if there is tiled sharding. Status RewriteLayoutWithShardedShape( const absl::optional& sharding, bool use_fast_memory, @@ -1542,4 +1505,44 @@ Status RewriteLayoutWithShardedShape( return Status::OK(); } +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index d67b1f26696..b95d250636a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -518,12 +518,22 @@ class XlaCompiler { TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; +// Creates an identity shape representation function. +XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn(); + // Rewrites the layout of xla_shape if there is tiled sharding. Status RewriteLayoutWithShardedShape( const absl::optional& sharding, bool use_fast_memory, XlaCompiler::ShapeRepresentationFn shape_representation_fn, xla::Shape* xla_shape); +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 17fb4c3c369..044a742eddd 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -434,17 +434,21 @@ XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, (value_type == F64 && bit_type == U64)); // Form random mantissa bits for float/double, with a leading 1 bit. - int float_bits = primitive_util::BitWidth(value_type); + int num_float_bits = primitive_util::BitWidth(value_type); // Subtract one as SignificandWidth includes the leading 1 bit. - int mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; + int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; - bits = ShiftRightLogical(bits, ScalarLike(bits, float_bits - mantissa_bits)) | - BitcastConvertType(ScalarLike(minval, 1.0), bit_type); - XlaOp values = BitcastConvertType(bits, value_type); + // Ignore the exponent bits and convert the mantissa bits to the floating + // point type. + bits = ShiftRightLogical( + bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - values = values - ScalarLike(values, 1.0); + // We have an integer-valued floating point number in the range + // [0, 2**{num_mantissa_bits}). + XlaOp values = ConvertElementType(bits, value_type); + + // Divide by 2**{-num_mantissa_bits} to get a number in the range [0.0, 1.0). + values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); // Multiply and add to shift to the range [minval, maxval). return values * (maxval - minval) + minval; diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index c721f3bea8b..128661ae8bd 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -636,8 +636,8 @@ StatusOr> PyLocalBuffer::CopyToDevice( } Status PyLocalBuffer::CopyToRemoteDevice( - absl::string_view serialized_descriptor, Device* dst_device) { - return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device); + absl::string_view serialized_descriptor) { + return client_->CopyToRemoteDevice(this, serialized_descriptor); } Status PyLocalBuffer::BlockHostUntilReady() { @@ -667,10 +667,11 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) { PyLocalExecutable::PyLocalExecutable( std::vector> executables, - DeviceAssignment device_assignment, PyLocalClient* client) + bool tuple_arguments, DeviceAssignment device_assignment, + PyLocalClient* client) : client_(client), - device_assignment_( - std::make_shared(device_assignment)) { + device_assignment_(std::make_shared(device_assignment)), + tuple_arguments_(tuple_arguments) { executables_.reserve(executables.size()); for (auto& executable : executables) { executables_.emplace_back(std::move(executable)); @@ -727,7 +728,7 @@ PyLocalExecutable::ExecuteHelper( std::unique_ptr tuple_buffer; std::vector tupled_arguments; - if (options.tuple_arguments) { + if (tuple_arguments_) { TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple( argument_handles, client_, device)); tupled_arguments = {tuple_buffer.get()}; @@ -1037,7 +1038,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation, build_options)); return absl::make_unique( - std::move(local_executables), build_options.device_assignment(), client); + std::move(local_executables), options.tuple_arguments, + build_options.device_assignment(), client); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 401064af77c..c9b50fbbbef 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -159,9 +159,8 @@ class PyLocalClient : public std::enable_shared_from_this { notifier(Unimplemented("Cross host receives not implemented.")); } - virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer, - absl::string_view serialized_descriptor, - Device* device) const { + virtual Status CopyToRemoteDevice( + PyLocalBuffer* buffer, absl::string_view serialized_descriptor) const { return Unimplemented("Cross host sends not implemented."); } @@ -275,17 +274,16 @@ class PyLocalBuffer { // Copies the buffer to device `dst_device`. StatusOr> CopyToDevice(Device* dst_device); - // Copies the buffer to remote device `dst_device`. This call must be preceded - // by a call to MakeCrossHostReceiveBuffers on the remote host's - // dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to - // construct the destination buffers, and a callback supplies an array - // containing both the destination buffers, and a serialized descriptor for - // each buffer. For each destination buffer there should be a matching call to - // src->CopyToRemoteDevice on a remote host for a src buffer of the - // corresponding shape. serialized_descriptor is the string returned by the - // callback along with the corresponding destination buffer. - Status CopyToRemoteDevice(absl::string_view serialized_descriptor, - Device* dst_device); + // Copies the buffer to the remote device encoded in serialized_descriptor. + // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the + // remote host's destination device. MakeCrossHostReceiveBuffers takes an + // array of shapes to construct the destination buffers, and a callback + // supplies an array containing both the destination buffers, and a serialized + // descriptor for each buffer. For each destination buffer there should be a + // matching call to src->CopyToRemoteDevice on a remote host for a src buffer + // of the corresponding shape. serialized_descriptor is the string returned by + // the callback along with the corresponding destination buffer. + Status CopyToRemoteDevice(absl::string_view serialized_descriptor); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -316,15 +314,15 @@ struct CompileOptions { // The layouts of the arguments that the computation should expect. absl::optional> argument_layouts; + // If true, the arguments to the computation will be wrapped in a tuple and + // passed as a single parameter. + bool tuple_arguments = false; + // XLA's compilation time options. ExecutableBuildOptions executable_build_options; }; struct ExecuteOptions { - // If true, the arguments to the computation will be wrapped in a tuple and - // passed as a single parameter. - bool tuple_arguments = false; - // If true, the computation must return a tuple, which will be destructured // into its elements. bool untuple_result = false; @@ -340,7 +338,8 @@ class PyLocalExecutable { CompileOptions options); PyLocalExecutable(std::vector> executables, - DeviceAssignment device_assignment, PyLocalClient* client); + bool tuple_arguments, DeviceAssignment device_assignment, + PyLocalClient* client); PyLocalClient* client() const { return client_; } @@ -404,6 +403,10 @@ class PyLocalExecutable { std::vector> executables_; std::shared_ptr device_assignment_; + // True if the executables were compiled expecting arguments in a single + // tuple. + const bool tuple_arguments_; + // The replica and partition indices of device_assignment_ to be run by this // client. On single-host platforms without partitioning, this is all replicas // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 56ac640cb6c..f062afc48a4 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -503,9 +503,10 @@ static std::shared_ptr LookupDevice(const PyTpuClient& client, PyTpuExecutable::PyTpuExecutable( std::unique_ptr compiled_program, DeviceAssignment device_assignment, std::shared_ptr client, - xla::Shape result_shape) + xla::Shape result_shape, bool tuple_arguments) : client_(std::move(client)), device_assignment_(std::move(device_assignment)), + tuple_arguments_(tuple_arguments), result_shape_(std::move(result_shape)) { VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString(); const int num_replicas = device_assignment_.replica_count(); @@ -612,7 +613,7 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) { } StatusOr>> PyTpuExecutable::Execute( - absl::Span argument_handles, bool tuple_arguments) { + absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute().", @@ -627,7 +628,7 @@ StatusOr>> PyTpuExecutable::Execute( std::vector all_core_arguments; std::unique_ptr tupled_arguments; - if (tuple_arguments) { + if (tuple_arguments_) { TF_ASSIGN_OR_RETURN(tupled_arguments, PyTpuBuffer::MakeTuple(argument_handles, client_, local_devices_.front())); @@ -658,8 +659,7 @@ StatusOr>> PyTpuExecutable::Execute( StatusOr>>> PyTpuExecutable::ExecuteOnLocalDevices( - absl::Span> argument_handles, - bool tuple_arguments) { + absl::Span> argument_handles) { tensorflow::profiler::TraceMe traceme( "PyTpuExecutable::ExecuteOnLocalDevices"); @@ -679,7 +679,7 @@ PyTpuExecutable::ExecuteOnLocalDevices( std::vector> tupled_arguments; std::vector> tupled_argument_pointers; - if (tuple_arguments) { + if (tuple_arguments_) { tupled_arguments.resize(argument_handles.size()); tupled_argument_pointers.resize(argument_handles.size()); for (int i = 0; i < num_local_devices; ++i) { @@ -750,7 +750,7 @@ PyTpuExecutable::ExecuteOnLocalDevices( absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, std::shared_ptr client, - absl::optional device_assignment) { + absl::optional device_assignment, bool tuple_arguments) { tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile"); VLOG(1) << "Compile: " @@ -814,7 +814,7 @@ PyTpuExecutable::ExecuteOnLocalDevices( return absl::make_unique( std::move(compiled_program), std::move(*device_assignment), - std::move(client), std::move(result_layout)); + std::move(client), std::move(result_layout), tuple_arguments); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 2b1ac4a3044..f30ce4fda17 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -268,12 +268,12 @@ class PyTpuExecutable { absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, std::shared_ptr client, - absl::optional device_assignment); + absl::optional device_assignment, bool tuple_arguments); PyTpuExecutable( std::unique_ptr compiled_program, DeviceAssignment device_assignment, std::shared_ptr client, - xla::Shape result_shape); + xla::Shape result_shape, bool tuple_arguments); virtual ~PyTpuExecutable() { for (auto it = executables_.begin(); it != executables_.end(); ++it) { client_->driver()->UnloadProgram(std::move(it->second), {}); @@ -309,7 +309,7 @@ class PyTpuExecutable { // inside for computation to finish. Coordinate with JAX code change to see if // we can make both Execute and ExecutePerReplica non-blocking. StatusOr>> Execute( - absl::Span argument_handles, bool tuple_arguments); + absl::Span argument_handles); // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local @@ -317,8 +317,7 @@ class PyTpuExecutable { // count. StatusOr>>> ExecuteOnLocalDevices( - absl::Span> argument_handles, - bool tuple_arguments); + absl::Span> argument_handles); void Delete() { executables_.clear(); } @@ -336,6 +335,7 @@ class PyTpuExecutable { std::shared_ptr const client_; std::map> executables_; const DeviceAssignment device_assignment_; + const bool tuple_arguments_; // The replica and partition indices of device_assignment_ to be run by this // client. On single-host platforms without partitioning, this is all replicas diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index 2c4be78c9c5..89338934904 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -100,7 +100,8 @@ class TpuBackend(xla_client.Backend): return _tpu_client.TpuExecutable.Compile(c_computation, compile_options.argument_layouts, options, self.client, - compile_options.device_assignment) + compile_options.device_assignment, + compile_options.tuple_arguments) def get_default_device_assignment(self, num_replicas, num_partitions=None): if num_partitions is not None: diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 752ea4c4907..88d17cb8e2a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -167,7 +167,8 @@ PYBIND11_MODULE(tpu_client_extension, m) { const ExecutableBuildOptions* build_options, std::shared_ptr client, absl::optional>> - device_assignment) + device_assignment, + bool tuple_arguments) -> StatusOr> { py::gil_scoped_release gil_release; absl::optional xla_device_assignment; @@ -178,7 +179,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { } return PyTpuExecutable::Compile( computation, argument_layouts, build_options, client, - std::move(xla_device_assignment)); + std::move(xla_device_assignment), tuple_arguments); }) .def("local_logical_device_ids", &PyTpuExecutable::local_logical_device_ids) @@ -187,11 +188,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { &PyTpuExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyTpuExecutable::Delete) .def("Execute", &PyTpuExecutable::Execute, - py::call_guard(), py::arg("arguments"), - py::arg("tuple_arguments")) + py::call_guard(), py::arg("arguments")) .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, - py::call_guard(), py::arg("arguments"), - py::arg("tuple_arguments")); + py::call_guard(), py::arg("arguments")); py::class_>(m, "TpuDevice") .def_property_readonly("coords", &TpuDevice::coords) diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 60952c393ab..2affd4b30fa 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -1040,7 +1040,8 @@ PYBIND11_MODULE(xla_extension, m) { absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, std::shared_ptr client, - absl::optional device_assignment) + absl::optional device_assignment, + bool tuple_arguments) -> StatusOr> { py::gil_scoped_release gil_release; CompileOptions options; @@ -1048,6 +1049,7 @@ PYBIND11_MODULE(xla_extension, m) { if (build_options) { options.executable_build_options = *build_options; } + options.tuple_arguments = tuple_arguments; if (device_assignment) { options.executable_build_options.set_device_assignment( *device_assignment); @@ -1065,7 +1067,8 @@ PYBIND11_MODULE(xla_extension, m) { const ExecutableBuildOptions* build_options, std::shared_ptr client, absl::optional>> - device_assignment) + device_assignment, + bool tuple_arguments) -> StatusOr> { py::gil_scoped_release gil_release; CompileOptions options; @@ -1073,6 +1076,7 @@ PYBIND11_MODULE(xla_extension, m) { if (build_options) { options.executable_build_options = *build_options; } + options.tuple_arguments = tuple_arguments; if (device_assignment) { TF_ASSIGN_OR_RETURN( DeviceAssignment xla_assignment, @@ -1105,11 +1109,10 @@ PYBIND11_MODULE(xla_extension, m) { .def( "Execute", [](const PyLocalExecutable& executable, - absl::Span args, bool tuple_arguments) + absl::Span args) -> StatusOr>> { py::gil_scoped_release gil_release; ExecuteOptions options; - options.tuple_arguments = tuple_arguments; options.untuple_result = true; TF_ASSIGN_OR_RETURN( std::vector> output_buffers, @@ -1122,17 +1125,15 @@ PYBIND11_MODULE(xla_extension, m) { } return outputs; }, - py::arg("arguments"), py::arg("tuple_arguments")) + py::arg("arguments")) .def( "ExecuteOnLocalDevices", [](const PyLocalExecutable& executable, - absl::Span> args, - bool tuple_arguments) + absl::Span> args) -> StatusOr< std::vector>>> { py::gil_scoped_release gil_release; ExecuteOptions options; - options.tuple_arguments = tuple_arguments; options.untuple_result = true; TF_ASSIGN_OR_RETURN( std::vector>> @@ -1150,7 +1151,7 @@ PYBIND11_MODULE(xla_extension, m) { } return outputs; }, - py::arg("arguments"), py::arg("tuple_arguments")) + py::arg("arguments")) .def( "get_hlo_modules", [](const PyLocalExecutable& executable) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a7e8903b113..a52fa7545f1 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -147,7 +147,8 @@ class LocalBackend(Backend): return _xla.LocalExecutable.Compile(c_computation, compile_options.argument_layouts, options, self.client, - compile_options.device_assignment) + compile_options.device_assignment, + compile_options.tuple_arguments) def get_default_device_assignment(self, num_replicas, num_partitions=None): if num_partitions is not None: @@ -504,6 +505,7 @@ class CompileOptions(object): self.argument_layouts = None self.result_layout = None self.device_assignment = None + self.tuple_arguments = False class Computation(object): @@ -613,7 +615,7 @@ def execute_with_python_values(executable, arguments=(), backend=None): arg, device=executable.local_devices()[0], backend=backend) arguments = [put(arg) for arg in arguments] - outputs = executable.Execute(arguments, tuple_arguments=False) + outputs = executable.Execute(arguments) return [x.to_py() for x in outputs] @@ -642,8 +644,9 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): for replica_args in arguments: arg_buffers.append(flat_arg_buffers[:len(replica_args)]) flat_arg_buffers = flat_arg_buffers[len(replica_args):] - return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices( - arg_buffers, tuple_arguments=False)] + return [[x.to_py() + for x in xs] + for xs in executable.ExecuteOnLocalDevices(arg_buffers)] class PaddingType(enum.Enum): @@ -868,7 +871,7 @@ class ComputationBuilder(object): shape, name=None, parameter_num=None, - replicated=False): + replicated=None): """Enqueues a Parameter op onto the computation, given a shape. Args: @@ -880,6 +883,7 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. replicated: whether to mark the parameter's leaves as replicated. May be a bool, in which case it applies to all leaves, or an iterable of bools. + The default is None, which means no replication annotation. Returns: An XlaOp. @@ -888,7 +892,9 @@ class ComputationBuilder(object): name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - if isinstance(replicated, bool): + if replicated is None: + replicated = [] + elif isinstance(replicated, bool): replicated = [replicated] * shape.leaf_count() return ops.Parameter(self._builder, parameter_num, diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 848e8c881d2..36d5da2841b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -494,7 +494,7 @@ class BufferTest(ComputationTest): arg_buffer = xla_client.Buffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(RuntimeError): - compiled_c.Execute([arg_buffer], tuple_arguments=False) + compiled_c.Execute([arg_buffer]) def testShape(self): pyval = np.array([[1., 2.]], np.float32) @@ -1903,8 +1903,7 @@ class EmbeddedComputationsTest(ComputationTest): compiled_c = c.Build().Compile() for want in to_round_trip: - execution = threading.Thread( - target=lambda: compiled_c.Execute([], tuple_arguments=False)) + execution = threading.Thread(target=lambda: compiled_c.Execute([])) execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 6d470149ca8..5faf58f0c22 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1117,6 +1117,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", @@ -2228,6 +2230,7 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -2404,6 +2407,7 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/core/platform:macros", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index e8fabc1d8f7..3e9daa96150 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1008,7 +1008,22 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { // Try allocate same buffer for dynamic update slice's operand and output. - // + + // If memory_space_assignment is run and there is information about a color in + // preset assignments, don't merge those buffers. We expect + // memory_space_assignment to have merged these buffers. If + // memory_space_assignment didn't merge these buffers and have assigned + // different offsets to the operand and the output buffer, merging the buffers + // can cause memory corruption if memory_space_assignment assigned a different + // buffer at the same offset. + absl::flat_hash_set excluded_colors; + if (preset_assignments_) { + for (const auto& color_and_info : + preset_assignments_->assignment_informations()) { + excluded_colors.insert(color_and_info.first); + } + } + // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule // to operations that can be done in place. for (HloComputation* computation : assignment->module().computations()) { @@ -1039,6 +1054,13 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { assignment->alias_analysis().GetUniqueBufferAt( instruction->operand(0), {}); + // The instruction or operand color is excluded because it was assigned by + // memory_space_assignment. + if (excluded_colors.contains(instruction_buffer.color().value()) || + excluded_colors.contains(operand_buffer.color().value())) { + continue; + } + // Already have the same buffer. No need to merge those. if (instruction_buffer.id() == operand_buffer.id()) { continue; diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.cc b/tensorflow/compiler/xla/service/collective_ops_utils.cc index a4eba334f31..126c3e33832 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.cc +++ b/tensorflow/compiler/xla/service/collective_ops_utils.cc @@ -30,13 +30,18 @@ absl::optional MatchReductionComputation( .WithShape(m::Shape().IsEffectiveScalar())); }; + // Match the operation to a reduction kind. We can represent and/or of pred as + // min/max. This works because pred is stored as an 8-bit int of value 0 or 1. + PrimitiveType type = computation->root_instruction()->shape().element_type(); if (match_opcode(HloOpcode::kAdd)) { return ReductionKind::SUM; } else if (match_opcode(HloOpcode::kMultiply)) { return ReductionKind::PRODUCT; - } else if (match_opcode(HloOpcode::kMinimum)) { + } else if (match_opcode(HloOpcode::kMinimum) || + (type == PRED && match_opcode(HloOpcode::kAnd))) { return ReductionKind::MIN; - } else if (match_opcode(HloOpcode::kMaximum)) { + } else if (match_opcode(HloOpcode::kMaximum) || + (type == PRED && match_opcode(HloOpcode::kOr))) { return ReductionKind::MAX; } else { return absl::nullopt; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 60e184411e9..d2e279f5e21 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -278,6 +278,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous { case xla::S8: DoAllReduce(participant); break; + case xla::PRED: case xla::U8: DoAllReduce(participant); break; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 361d4b9c842..d933380442f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -44,6 +44,12 @@ class CpuTransferManager : public GenericTransferManager { const Shape& literal_shape, MutableBorrowingLiteral literal) override; + bool CanShapedBufferBeAccessedNow( + se::StreamExecutor* executor, + const ShapedBuffer& device_buffer) const override { + return true; + } + private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3f5eb3bdf0c..d2b2b4534a5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1416,6 +1416,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { bool is_datatype_supported = [&] { // TODO(cheshire): Fix duplication wrt. cpu_runtime switch (datatype) { + case PRED: case S8: case U8: case S32: diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 34d144ea1e9..94815e2fdbc 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" namespace xla { @@ -250,15 +251,25 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { } const PaddingConfig_PaddingConfigDimension& padding_config = hlo->padding_config().dimensions(dimension); - if (padding_config.interior_padding() == 0 && - padding_config.edge_padding_low() == 0 && - padding_config.edge_padding_high() == 0) { - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + if (padding_config.interior_padding() == 0) { + HloInstruction* dynamic_size_adjusted = dynamic_size; + HloInstruction* adjustment = hlo->parent()->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + padding_config.edge_padding_low() + + padding_config.edge_padding_high()))); + dynamic_size_adjusted = + hlo->parent()->AddInstruction(HloInstruction::CreateBinary( + dynamic_size_adjusted->shape(), HloOpcode::kAdd, + dynamic_size_adjusted, adjustment)); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted, + constraint); return Status::OK(); } else { return Unimplemented( - "Dynamic dimension propagation on padding dimension is not " - "supported."); + "Dynamic dimension propagation on interio padding dimension is " + "not " + "supported: %s", + hlo->ToString()); } }); } @@ -400,11 +411,19 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { - parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], - dynamic_size, constraint); + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size, + DimensionConstraint constraint) -> Status { + int64 permuted_dim = -1; + for (int64 i = 0; i < hlo->dimensions().size(); ++i) { + if (hlo->dimensions()[i] == dimension) { + TF_RET_CHECK(permuted_dim == -1); + permuted_dim = i; + } + } + parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size, + constraint); return Status::OK(); }); } @@ -979,14 +998,8 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) { hlo->slice_strides(dimension) != 1 || hlo->slice_limits(dimension) != operand->shape().dimensions(dimension)) { - // Slicing a single element out eliminates the dynamic dimension. - if (hlo->shape().dimensions(dimension) == 1) { - return Status::OK(); - } - return Unimplemented( - "Dynamic dimension propagation on Slice where it doesn't slice " - "out an entire dimension is not supported %s", - hlo->ToString()); + // Slicing a partial element out eliminates the dynamic dimension. + return Status::OK(); } parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index d2913f9d2a1..dbe57985fd4 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -386,6 +386,53 @@ TEST_F(DynamicDimensionInferenceTest, DotTestBatch) { EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr); } +TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) { + auto builder = HloComputation::Builder(TestName()); + auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512}); + auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, lhs_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, rhs_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 1})); + + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(RunInference()); + // Nothing is dynamic in the output. + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr); +} + TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { auto builder = HloComputation::Builder(TestName()); constexpr int xdim = 3; @@ -474,6 +521,45 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) { EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); } +TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2); +} + TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { // Test the ability to trace unmodified reshape dimensions. auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index c37f9d0c3db..e669bc4dbe2 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -865,6 +865,45 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, DynamicPad) { + const string hlo_text = R"( +HloModule TEST + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[4] parameter(0) + size = s32[] constant(3) + padding = s32[] constant(2) + param_dynamic = s32[4] set-dimension-size(param, size), + dimensions={0} + // pad head and tail to 2 + pad = s32[6] pad(param_dynamic, padding), padding=1_1 + + init = s32[] constant(0) + ROOT reduce = s32[] reduce(pad, init), + dimensions={0}, + to_apply=update_s32 +} +)"; + + Literal operand = LiteralUtil::CreateR1({1, 4, 3, 5}); + auto module = GetHloModule(hlo_text); + + // After padding head and tail with "2", the effective data will be [2, 1, 4, + // 3, 2] + + Literal result = PadAndExecute(std::move(module), {&operand}, + /*slice_dynamic_output=*/false); + Literal expected = LiteralUtil::CreateR0(12); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, DynamicTupleSort) { const string hlo_text = R"( HloModule TEST diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4859759eba5..8a9a96ce363 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -45,6 +45,18 @@ namespace xla { // TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be // revisited, with the execute APIs taking data structure which can better model // shareable buffers. +// +// ExecutionInput buffers are in one of three states: +// +// 1) Owned by the caller and immutable. +// 2) Donated by the caller but returned on error. +// 3) Donated by the caller and freed on error. +// +// Case (1) buffers are stored as MaybeOwningDeviceMemory(DeviceMemoryBase). +// Case (2) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory), +// with their indices present in unowned_indices_. +// Case (3) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory), +// with their indices absent from unowned_indices_. class ExecutionInput { public: ExecutionInput() = default; @@ -80,6 +92,10 @@ class ExecutionInput { unowned_indices_.push_back(index); } + void SetUnownedIndex(const ShapeIndex& index) { + unowned_indices_.push_back(index); + } + const ShapeTree& Buffers() const { return buffers_; } ShapeTree* MutableBuffers() { return &buffers_; } @@ -94,6 +110,8 @@ class ExecutionInput { private: ShapeTree buffers_; + // (Unordered) set of indices of buffers that should be returned to the + // caller if an error occurs when enqueuing the computation. std::vector unowned_indices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4a903548c22..0877ac2cfc7 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -11,11 +11,13 @@ load( ) load( "//tensorflow:tensorflow.bzl", + "if_cuda_or_rocm", "tf_cc_test", "tf_copts", "tf_cuda_library", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -420,9 +422,24 @@ filegroup( ), ) +# use alias since nested select statements not possible +cc_library( + name = "empty", +) + +alias( + name = "virtual_nccl", + actual = if_cuda("@local_config_nccl//:nccl", ":empty"), +) + +alias( + name = "virtual_rccl", + actual = if_rocm("@local_config_rocm//rocm:rccl", ":empty"), +) + tf_cuda_library( name = "nccl_all_reduce_thunk", - srcs = if_cuda( + srcs = if_cuda_or_rocm( [":nccl_all_reduce_thunk_src"], ["dummy_all_reduce_thunk.cc"], ), @@ -443,10 +460,15 @@ tf_cuda_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + ] + if_cuda([ "//tensorflow/stream_executor/cuda:cuda_activation", "//tensorflow/stream_executor/cuda:cuda_gpu_executor", - ] + if_nccl([ - "@local_config_nccl//:nccl", + ]) + if_rocm([ + "//tensorflow/stream_executor/rocm:rocm_activation", + "//tensorflow/stream_executor/rocm:rocm_gpu_executor", + ]) + if_nccl([ + ":virtual_nccl", + ":virtual_rccl", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3a8c3321e24..87054d8322a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -518,7 +518,9 @@ StatusOr> GpuCompiler::RunBackend( << "Invalid LLVM IR before optimizations:\n" << err_stream.str() << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_to to get the IR. "; + "Rerun with --xla_dump_to to get the IR and looks for files with " + "name containing: *" + << FilenameFor(*module, "", "") << "*"; } GpuVersion gpu_version = GetGpuVersion(stream_exec); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2f8ce62dd84..1316e8ad1aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include #include #include #include @@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) { return false; } +std::vector ExtractRelativeOrderOfNontrivialDims(const Shape& shape) { + std::vector relative_order; + for (int64 dim : LayoutUtil::MinorToMajor(shape)) { + if (shape.dimensions(dim) > 1) { + relative_order.push_back(dim); + } + } + // Now normalize the dimensions to values between 0 and true rank - 1. + std::vector sorted_dims = relative_order; + std::sort(sorted_dims.begin(), sorted_dims.end()); + for (int64& dim : relative_order) { + int64 sorted_index = std::distance( + sorted_dims.begin(), + std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim)); + dim = sorted_index; + } + return relative_order; +} + } // namespace bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, @@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, std::vector params; AppendParams(producer, ¶ms); AppendParams(reduce, ¶ms); - int64 max_rank = -1; - const Layout* max_rank_layout; + int64 max_true_rank = -1; + std::vector max_rank_order; for (HloInstruction* param : params) { - if (param->shape().IsArray() && param->shape().rank() > max_rank) { - max_rank = param->shape().rank(); - max_rank_layout = ¶m->shape().layout(); + if (param->shape().IsArray() && + ShapeUtil::TrueRank(param->shape()) > max_true_rank) { + max_true_rank = ShapeUtil::TrueRank(param->shape()); + max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape()); } } return absl::c_all_of(params, [&](HloInstruction* param) { - return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) || - (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + return !param->shape().IsArray() || + ShapeUtil::TrueRank(param->shape()) < max_true_rank || + ExtractRelativeOrderOfNontrivialDims(param->shape()) == + max_rank_order; }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index ae31b10deb3..854aab86b8e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); } +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0) + p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT + ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1,32,32]{1,3,2,0} parameter(0) + p1 = f16[128,1,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(1); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect); + EXPECT_TRUE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( fused_reduce { @@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest, } TEST_F(GpuFusibleTest, - LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) { + LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( broadcasting_computation { p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) - p1.1 = f32[128]{0} parameter(1) - broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0} + p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1) + reshape = f32[128]{0} reshape(p1.1) + broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0} ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast) } ENTRY entry { p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0) - p1 = f32[128]{0} parameter(1) + p1 = f32[1,128,1,1]{3,2,1,0} parameter(1) loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation c0.2 = f32[] constant(0) ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 8d568e7f5d4..344dee56ee4 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -29,7 +29,11 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#if GOOGLE_CUDA #include "third_party/nccl/nccl.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/rccl/rccl.h" +#endif #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/collective_ops_utils.h" @@ -39,7 +43,17 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#include "tensorflow/stream_executor/gpu/gpu_activation.h" + +#if TENSORFLOW_USE_ROCM +// Local hipify of cuda symbols +#define cudaError_t hipError_t +#define cudaStream_t hipStream_t +#define cudaGetErrorString hipGetErrorString +#define cudaGetDevice hipGetDevice +#define cudaSetDevice hipSetDevice +#define cudaSuccess hipSuccess +#endif namespace xla { namespace gpu { @@ -159,6 +173,7 @@ absl::optional DatatypeToNccl(PrimitiveType element_type) { switch (element_type) { case S8: return ncclInt8; + case PRED: case U8: return ncclUint8; case S32: @@ -443,7 +458,7 @@ RendezvousNcclAllReduce::SubmitParticipantImpl( ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind); se::StreamExecutor* executor = participant.stream->parent(); - se::cuda::ScopedActivateExecutorContext scoped_context(executor); + se::gpu::ScopedActivateExecutorContext scoped_context(executor); cudaStream_t* cu_stream = reinterpret_cast( participant.stream->implementation()->GpuStreamMemberHack()); VLOG(3) << "Using stream pointer: " << cu_stream diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index 871692a7b26..4203622f53d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -88,6 +88,38 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { /*match_optimized_ir=*/true); } +TEST_F(GpuIndexTest, + ReuseMultidimIndexWithTrivialReshapeAndNonContiguousBroadcast) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[1,7,2,5,3]{4,3,2,1,0} parameter(0) + y = f32[2,1,3]{2,1,0} parameter(1) + reshape = f32[1,2,3]{2,1,0} reshape(y) + broadcast = f32[1,7,2,5,3]{4,3,2,1,0} broadcast(reshape), dimensions={0,2,4} + ROOT gte = pred[1,7,2,5,3]{4,3,2,1,0} compare(x, broadcast), direction=GE + })", + config) + .ValueOrDie(); + CompileAndVerifyIr(std::move(module), + R"( +; CHECK: %[[tmp4:.*]] = udiv i32 %[[linear_index:.*]], 1 +; CHECK: %[[dim4:.*]] = urem i32 %[[tmp4]], 3 +; CHECK: %[[tmp3:.*]] = udiv i32 %[[linear_index]], 3 +; CHECK: %[[dim3:.*]] = urem i32 %[[tmp3]], 5 +; CHECK: %[[tmp2:.*]] = udiv i32 %[[linear_index]], 15 +; CHECK: %[[dim2:.*]] = urem i32 %[[tmp2]], 2 +; CHECK: %[[tmp1:.*]] = udiv i32 %[[linear_index]], 30 +; CHECK: %[[dim1:.*]] = urem i32 %[[tmp1]], 7 +; CHECK: %[[dim0:.*]] = udiv i32 %[[linear_index]], 210 +; CHECK: %{{.*}} = getelementptr inbounds [2 x [1 x [3 x float]]], [2 x [1 x [3 x float]]]* %{{.*}}, i32 0, i32 %[[dim2]], i32 0, i32 %[[dim4]] + )", + /*match_optimized_ir=*/false); +} + TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { HloModuleConfig config; config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 9ad07df8e9a..c4911df150f 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -358,6 +358,11 @@ message DynamicParameterBindingProto { repeated Binding entries = 1; } +message CrossProgramPrefetch { + int64 parameter = 1; + repeated int64 index = 2; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -381,6 +386,8 @@ message HloModuleProto { HloInputOutputAliasProto input_output_alias = 8; DynamicParameterBindingProto dynamic_parameter_binding = 9; + + repeated CrossProgramPrefetch cross_program_prefetches = 10; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9698735b509..484ed3eaa6c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3525,17 +3525,6 @@ bool HloPtrComparator::operator()(const HloInstruction* const& lhs, return lhs->unique_id() < rhs->unique_id(); } -bool HloInstruction::CouldBeBitcast() const { - switch (opcode_) { - case HloOpcode::kTranspose: - return true; - case HloOpcode::kReshape: - return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions()); - default: - return false; - } -} - Status HloInstruction::GetBackendConfigInternal( tensorflow::protobuf::Message* proto) const { proto->Clear(); @@ -3648,6 +3637,10 @@ const std::vector& HloInstruction::slice_starts() const { return Cast(this)->slice_starts(); } +std::vector* HloInstruction::mutable_slice_starts() { + return Cast(this)->mutable_slice_starts(); +} + int64 HloInstruction::slice_limits(int64 dimension) const { return Cast(this)->slice_limits(dimension); } @@ -3656,6 +3649,10 @@ const std::vector& HloInstruction::slice_limits() const { return Cast(this)->slice_limits(); } +std::vector* HloInstruction::mutable_slice_limits() { + return Cast(this)->mutable_slice_limits(); +} + int64 HloInstruction::slice_strides(int64 dimension) const { return Cast(this)->slice_strides(dimension); } @@ -3664,6 +3661,10 @@ const std::vector& HloInstruction::slice_strides() const { return Cast(this)->slice_strides(); } +std::vector* HloInstruction::mutable_slice_strides() { + return Cast(this)->mutable_slice_strides(); +} + const Literal& HloInstruction::literal() const { return Cast(this)->literal(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 22220ccc2d5..fdeea10c496 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1562,10 +1562,6 @@ class HloInstruction { // Returns the module for this instruction. HloModule* GetModule() const; - // Returns whether we could assign input and output layouts to this - // instruction to make it a bitcast. - bool CouldBeBitcast() const; - // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. @@ -1620,14 +1616,17 @@ class HloInstruction { // Delegates to HloSliceInstruction::slice_start. int64 slice_starts(int64 dimension) const; const std::vector& slice_starts() const; + std::vector* mutable_slice_starts(); // Delegates to HloSliceInstruction::slice_limits. int64 slice_limits(int64 dimension) const; const std::vector& slice_limits() const; + std::vector* mutable_slice_limits(); // Delegates to HloSliceInstruction::slice_strides. int64 slice_strides(int64 dimension) const; const std::vector& slice_strides() const; + std::vector* mutable_slice_strides(); // Returns the literal associated with this instruction. const Literal& literal() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 8770e9be342..0cf8f7e6eb0 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -700,17 +700,20 @@ class HloSliceInstruction : public HloInstruction { // Returns the start index in the given dimension for a slice node. int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } const std::vector& slice_starts() const { return slice_starts_; } + std::vector* mutable_slice_starts() { return &slice_starts_; } // Returns the (exclusive) limit index in the given dimension for a slice // node. int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } const std::vector& slice_limits() const { return slice_limits_; } + std::vector* mutable_slice_limits() { return &slice_limits_; } // Returns the stride in the given dimension for a slice node. int64 slice_strides(int64 dimension) const { return slice_strides_[dimension]; } const std::vector& slice_strides() const { return slice_strides_; } + std::vector* mutable_slice_strides() { return &slice_strides_; } private: std::vector ExtraAttributesToStringImpl( @@ -738,6 +741,8 @@ class HloConstantInstruction : public HloInstruction { explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } + // Returns the (mutable) literal associated with this instruction. + Literal* mutable_literal() { return &literal_.value(); } // Returns whether there is literal associated with this instruction. bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c8a68db25d4..de65ed99303 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -263,6 +263,15 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); *proto.mutable_dynamic_parameter_binding() = dynamic_parameter_binding().ToProto(); + for (const auto& parameter_indices : CrossProgramPrefetches()) { + const auto& parameter = parameter_indices.first; + const auto& indices = parameter_indices.second; + auto* prefetch = proto.mutable_cross_program_prefetches()->Add(); + prefetch->set_parameter(parameter); + for (auto index : indices) { + prefetch->add_index(index); + } + } return proto; } @@ -389,6 +398,12 @@ StatusOr> HloModule::CreateFromProto( TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); } + for (auto prefetch : proto.cross_program_prefetches()) { + module->AddCrossProgramPrefetch( + prefetch.parameter(), + ShapeIndex(prefetch.index().begin(), prefetch.index().end())); + } + return std::move(module); } @@ -669,6 +684,11 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, } TF_CHECK_OK(module->set_schedule(std::move(clone_schedule))); } + for (const auto& parameter_indices : CrossProgramPrefetches()) { + const auto& parameter = parameter_indices.first; + const auto& indices = parameter_indices.second; + module->AddCrossProgramPrefetch(parameter, indices); + } return module; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 38395f173e1..5f97d0c66b6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -345,6 +345,17 @@ class HloModule { spmd_output_sharding_ = sharding; } + // Add a program argument to be prefetched across programs. + void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) { + cross_program_prefetches_.emplace_back(parameter, index); + } + + // Get the list of program arguments to be prefetch across programs. + const absl::Span> CrossProgramPrefetches() + const { + return cross_program_prefetches_; + } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -392,6 +403,9 @@ class HloModule { // The HLO sharding of the entry computation's output (root) for // SPMD-partitioned programs. absl::optional spmd_output_sharding_; + + // Arguments to be prefetched across programs. + std::vector> cross_program_prefetches_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 538a99b46ea..9701c343288 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -443,6 +443,25 @@ Shape HloSharding::TileShape(const Shape& shape) const { return result_shape; } +Shape HloSharding::TileShape(const Shape& shape, int64 device) const { + if (IsTileMaximal()) { + return shape; + } + + std::vector index = TileIndexForDevice(device); + Shape result_shape = shape; + for (int64 i = 0; i < index.size(); ++i) { + const int64 shape_dim = shape.dimensions(i); + int64 offset = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); + int64 limit = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); + result_shape.set_dimensions(i, limit - offset); + } + return result_shape; +} + HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 56479add95f..20fa7232e65 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -216,6 +216,10 @@ class HloSharding { // REQUIRES: !IsTuple() Shape TileShape(const Shape& shape) const; + // Gets the tile shape on the device. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape, int64 device) const; + private: HloSharding() : replicated_(true), diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 74e77510d2a..53938a489f1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -692,14 +692,6 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - if (producer->CouldBeBitcast() && - // We can't fuse parameters anyhow, so we leave the user unfused to become - // a bitcast. If the operand is not a parameter, we would break a - // potential fusion to make it a bitcast, which is not so clear a win. - producer->operand(0)->opcode() == HloOpcode::kParameter) { - return false; - } - return true; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 0f400f2b2ed..f4309ea09ae 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -110,54 +110,6 @@ TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { << module->ToString(); } -TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { - HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); - auto reshape1 = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - - auto module = CreateNewVerifiedModule(); - auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(reshape1, computation->root_instruction()); - EXPECT_FALSE( - InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - -TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { - HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); - auto reshape1 = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - - auto module = CreateNewVerifiedModule(); - auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(reshape1, computation->root_instruction()); - EXPECT_FALSE( - InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - -TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { - HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); - auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {}), param0, {})); - - auto module = CreateNewVerifiedModule(); - auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(transpose1, computation->root_instruction()); - EXPECT_FALSE( - InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 396fcf9e92e..da0dbf94ddd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include +#include + #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -137,40 +140,76 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { CHECK_EQ(multidim_.size(), output_shape.rank()); - const auto common_factors = - CommonFactors(AsInt64Slice(input_shape.dimensions()), - AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( input_shape.rank(), llvm::UndefValue::get(index_type_)); - // We compute the source indices in each common factor from only the target - // indices in the same common factor. - for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { - absl::Span dimensions = - AsInt64Slice(output_shape.dimensions()) - .subspan(common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second); - llvm::Value* logical_linear_index = - Index(absl::Span(multidim_).subspan( - common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second), - dimensions, index_type_) - .Linearize(dimensions, builder); - // Delinearizes logical_linear_index for the source array in row-major - // collapsed order. The first rank-1 indices are the remainder of the - // linear index by each dimension size. - for (int64 i = common_factors[k + 1].first - 1; - i >= common_factors[k].first; --i) { - llvm::Value* divisor = - GetConstantWithIndexType(input_shape.dimensions(i)); - if (input_shape.dimensions(i) == 1) { - source_multidim_index[i] = GetConstantWithIndexType(0); - } else if (i == common_factors[k].first) { - source_multidim_index[i] = logical_linear_index; + auto trivial_reshape = + ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape); + if (std::get<0>(trivial_reshape)) { + // The 1-sized dimensions which only appear in 'input_shape'. + auto deleted_dims_indices = std::get<1>(trivial_reshape); + // The 1-sized dimensions which only appear in 'output_shape'. + auto inserted_dims_indices = std::get<2>(trivial_reshape); + + // This is a two-way merge of 'deleted_dims_indices' with indexing into + // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices' + // with indexing into 'multidim_'. When we find a dimension in + // 'source_multidim_index' which does not belong to 'deleted_dims_indices', + // we retrieve the corresponding value from 'multidim_' (skipping any + // indices that appear in 'inserted_dims_indices'). + for (int64 i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size(); + ++i) { + if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) { + // This is a dimension that was preserved. Take the matching value from + // multidim_. + while (l < inserted_dims_indices.size() && + inserted_dims_indices[l] == k) { + // Skip 1-sized dimensions. + ++k; + ++l; + } + source_multidim_index[i] = multidim_[k]; + ++k; } else { - source_multidim_index[i] = - builder->CreateURem(logical_linear_index, divisor); + // This is a 1-sized dimension that only appears in the operand. + source_multidim_index[i] = GetConstantWithIndexType(0); + ++j; + } + } + } else { + const auto common_factors = + CommonFactors(AsInt64Slice(input_shape.dimensions()), + AsInt64Slice(output_shape.dimensions())); + // We compute the source indices in each common factor from only the target + // indices in the same common factor. + for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { + absl::Span dimensions = + AsInt64Slice(output_shape.dimensions()) + .subspan(common_factors[k].second, + common_factors[k + 1].second - common_factors[k].second); + llvm::Value* logical_linear_index = + Index(absl::Span(multidim_).subspan( + common_factors[k].second, + common_factors[k + 1].second - common_factors[k].second), + dimensions, index_type_) + .Linearize(dimensions, builder); + // Delinearizes logical_linear_index for the source array in row-major + // collapsed order. The first rank-1 indices are the remainder of the + // linear index by each dimension size. + for (int64 i = common_factors[k + 1].first - 1; + i >= common_factors[k].first; --i) { + llvm::Value* divisor = + GetConstantWithIndexType(input_shape.dimensions(i)); + if (input_shape.dimensions(i) == 1) { + source_multidim_index[i] = GetConstantWithIndexType(0); + } else if (i == common_factors[k].first) { + source_multidim_index[i] = logical_linear_index; + } else { + source_multidim_index[i] = + builder->CreateURem(logical_linear_index, divisor); + } + logical_linear_index = + builder->CreateUDiv(logical_linear_index, divisor); } - logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor); } } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 4dc1c5782b6..f812165be04 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -355,6 +355,17 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { continue; } + HloInstruction* inst = interval.buffer->instruction(); + HloModule* module = inst->GetModule(); + + // Don't intra-program prefetch a cross program prefetch + if (inst->opcode() == HloOpcode::kParameter && + absl::c_count(module->CrossProgramPrefetches(), + std::make_pair(inst->parameter_number(), + interval.buffer->index())) > 0) { + continue; + } + auto colocated_intervals = GetSortedColocatedIntervals(interval); if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { @@ -561,6 +572,52 @@ AlternateMemoryBestFitHeap::GetLiveAllocationAt( return nullptr; } +void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( + HloModule* module, absl::optional prefetch_candidate) { + if (!prefetch_candidate) { + return; + } + + ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate); + if (chunk_candidate.chunk.offset != 0 || + chunk_candidate.heap_size > available_heap_size()) { + LOG(WARNING) + << "Could not allocate preferred memory for cross program prefetch"; + return; + } + AddToPendingChunks(*prefetch_candidate, chunk_candidate); + + const HloValue* buffer = prefetch_candidate->buffer; + int64 parameter = buffer->instruction()->parameter_number(); + module->AddCrossProgramPrefetch(parameter, buffer->index()); + + allocation_sequence_list_->push_back({buffer, {}}); + MemorySpaceAssignment::AllocationSequence& allocations = + allocation_sequence_list_->back().sequence; + + allocations.push_back(absl::make_unique( + buffer->defining_position(), MemorySpace::kDefault, kDummyChunk, + prefetch_candidate->start, prefetch_candidate->end)); + + // Sort the uses by the use time. + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + auto uses = buffer->uses(); + auto first_use = + absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) { + return instruction_schedule.at(lhs.instruction) < + instruction_schedule.at(rhs.instruction); + }); + int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction); + + AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, + chunk_candidate.chunk, prefetch_candidate->start, + prefetch_candidate->end, latest_prefetch_time, &allocations); + absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); + + pending_chunks_.clear(); + pending_async_copies_.clear(); +} + void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { // Go through the parameters and outputs and pin them to the corresponding // memory by adding a required assignment. @@ -1207,6 +1264,90 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( }; } +namespace { + +bool LooksLikeAnActivation(const HloInstruction* inst) { + for (HloInstruction* user : inst->users()) { + switch (user->opcode()) { + case HloOpcode::kConvolution: + case HloOpcode::kDot: + if (user->operand(0) == inst) { + return true; + } + break; + case HloOpcode::kGather: + if (user->operand(1) == inst) { + return true; + } + break; + case HloOpcode::kFusion: + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) == inst && + LooksLikeAnActivation(user->fused_parameter(i))) { + return true; + } + } + break; + default: + return true; + } + } + return false; +} + +bool IsCrossProgramPrefetchCandidate( + const HloValue& value, const MemorySpaceAssignment::Options& options) { + return value.instruction()->parent() == + value.instruction()->GetModule()->entry_computation() && + value.instruction()->opcode() == HloOpcode::kParameter && + value.index().size() == 1 && value.shape().IsArray() && + !value.uses().empty() && + options.size_fn(value) <= options.max_size_in_bytes && + absl::c_all_of(value.uses(), [&](const HloUse& use) { + const HloInstruction* gte = + use.instruction->operand(use.operand_number); + return gte->opcode() == HloOpcode::kGetTupleElement && + !LooksLikeAnActivation(gte); + }); +} + +absl::optional +FindCrossProgramPrefetchCandidate( + const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, + const MemorySpaceAssignment::Options& options) { + std::vector candidates; + for (HloValue* value : alias_analysis.dataflow_analysis().values()) { + if (IsCrossProgramPrefetchCandidate(*value, options)) { + MemorySpaceAssignment::BufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.need_allocation = true; + candidates.emplace_back(interval); + } + } + + // The buffer_interval_compare ought to do a good job picking the most + // appropriate buffer to cross program prefetch, but empirically, it makes + // worse choices than just picking the largest buffer. + // TODO(b/152421603): Investigate. + auto size_compare = [](const auto& x, const auto& y) { + return x.size < y.size; + }; + auto& compare = options.default_cross_program_prefetch_heuristic && + options.buffer_interval_compare + ? *options.buffer_interval_compare + : size_compare; + + auto best_candidate = absl::c_max_element(candidates, compare); + if (best_candidate == candidates.end()) { + return absl::nullopt; + } + return *best_candidate; +} +} // namespace + /*static*/ StatusOr> MemorySpaceAssignment::Run(HloModule* module, const HloLiveRange& hlo_live_range, @@ -1222,6 +1363,13 @@ MemorySpaceAssignment::Run(HloModule* module, &memory_space_assignment.allocation_sequence_list_, options, alias_analysis, hlo_live_range); + if (options.enable_cross_program_prefetch) { + absl::optional prefetch_candiate = + FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range, + options); + algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate); + } + HeapSimulator::Options heap_simulator_options; heap_simulator_options.may_reuse_operand_buffers = false; TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module, diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 51ff5329482..4a24f60b6a9 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -319,6 +319,14 @@ class MemorySpaceAssignment { // If true, verifies the memory space assignment against overlapping // buffers. bool verify = false; + + // Enable prefetching buffers into preferred memory across program + // boundaries + bool enable_cross_program_prefetch = true; + + // If true, use buffer_interval_compare to determine which buffers to + // prefetch across program boundaries. + bool default_cross_program_prefetch_heuristic = false; }; // This class represents an allocation that might either be in the default or @@ -623,6 +631,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { } } + // Allocates a buffer in preferred memory with whole program lifetime and + // enables prefetching prefech_candidate from default memory across program + // boundaries. + void AllocateCrossProgramPrefetchBuffer( + HloModule* module, absl::optional prefetch_candidate); + HeapSimulator::Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 31967e94c46..85a5e7a87a2 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -3028,5 +3028,203 @@ TEST_F(AsynchronousCopyOrderingTest, Simple) { ordering.AddCopy({5, 14, alternate_mem_space}); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, lhs, rhs, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(tuple_shape, param, 0)); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, gte, lhs, rhs, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, rhs_shape, "p0")); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 8; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, lhs, rhs, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 2; + constexpr int kFeature = 2; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + + auto module = CreateNewVerifiedModule(); + HloComputation::Builder fusion_builder("fusion"); + { + HloInstruction* param = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto lhs = fusion_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = fusion_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + (void)dot; + } + HloComputation* fusion_computation = + module->AddEmbeddedComputation(fusion_builder.Build()); + + auto activations = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0.0, 1.0}, {2.0, 3.0}}))); + auto weights = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0.0, 1.0}, {2.0, 3.0}}))); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({activations, weights})); + HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( + result_shape, HloInstruction::FusionKind::kCustom, {tuple}, + fusion_computation)); + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {activations, weights, tuple, fusion}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD index ab02cfae96b..3489018973d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/BUILD @@ -27,6 +27,7 @@ cc_library( srcs = ["conv_emitter.cc"], hdrs = ["conv_emitter.h"], deps = [ + ":conv_emitter_transforms", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", @@ -39,6 +40,22 @@ cc_library( ], ) +cc_library( + name = "conv_emitter_transforms", + srcs = ["conv_emitter_transforms.cc"], + hdrs = ["conv_emitter_transforms.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + ], +) + tf_cc_test( name = "conv_emitter_test", srcs = ["conv_emitter_test.cc"], diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index 5ec8d3bb334..c17d686f7dc 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h" #include "tensorflow/compiler/xla/window_util.h" namespace xla { @@ -109,48 +110,6 @@ ShapeInfo GetShapeInfo( return shape_info; } -bool IsSimpleLoop(mlir::AffineForOp loop) { - return loop.getLowerBoundMap().isSingleConstant() && - loop.getLowerBoundMap().getSingleConstantResult() == 0 && - loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 && - std::next(loop.region().begin()) == loop.region().end(); -} - -struct BoundAffineMap { - mlir::AffineMap affine_map; - std::vector operands; -}; - -BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { - if (auto load = mlir::dyn_cast(op)) { - return {load.getAffineMap(), - std::vector(load.getMapOperands().begin(), - load.getMapOperands().end())}; - } else if (auto store = mlir::dyn_cast(op)) { - return {store.getAffineMap(), - std::vector(store.getMapOperands().begin(), - store.getMapOperands().end())}; - } else { - CHECK(false); - } -} - -mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, - BoundAffineMap new_affine, - mlir::OpBuilder builder) { - if (auto load = mlir::dyn_cast(op)) { - return builder.create( - builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map, - new_affine.operands); - } else if (auto store = mlir::dyn_cast(op)) { - return builder.create( - builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(), - new_affine.affine_map, new_affine.operands); - } else { - CHECK(false); - } -} - void SetMemRef(mlir::Operation* op, mlir::Value memref) { if (auto load = mlir::dyn_cast(op)) { load.setMemRef(memref); @@ -161,127 +120,6 @@ void SetMemRef(mlir::Operation* op, mlir::Value memref) { } } -std::vector CreateNestedSimpleLoops( - absl::Span upper_bounds, mlir::OpBuilder builder) { - std::vector loops; - loops.reserve(upper_bounds.size()); - for (int64_t dim : upper_bounds) { - auto loop = - builder.create(builder.getUnknownLoc(), 0, dim); - loops.push_back(loop); - builder = loop.getBodyBuilder(); - } - return loops; -} - -void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, - mlir::OpBuilder builder) { - CHECK(IsSimpleLoop(loop)); - - loop.setUpperBoundMap(mlir::AffineMap::get( - loop.getUpperBoundMap().getNumDims(), - loop.getUpperBoundMap().getNumSymbols(), {new_bound})); -} - -// Tile a loop with trip count N by `size`. For now, N has to be a multiple of -// size, but later this constraint will be removed. -// -// The major loop (with trip count N / size) stays as-is, while the minor loop -// (with trip count `size`) will take over the body of `target`, and be placed -// as the new body of `target`. -// -// `target` has to be within the same "perfectly nested loop group" as `loop`. -// See the documentation for mlir::getPerfectlyNestedLoops. -// -// Example: -// Before tiling `loop` with tile size X: -// for (loop in N) -// for (unrelated_loop in ...) -// for (target in ...) -// // pass loop into affine maps -// After: -// for (loop in N / X) -// for (unrelated_loop in ...) -// for (target in ...) -// for (tiled_loop in X) -// // rewrite all affine exprs from loop to `loop * X + tiled_loop`. -// -// Design note: -// TileLoop is different from mlir::tile. At the moment, mlir::tile is not well -// documented about the exact tiling semantics, but the observed behavior is: -// for (i from 0 to N) -// for (unrelated_loop in ...) -// for (target in ...) -// // pass i into affine maps -// => -// for (i from 0 to N, step = X) -// for (unrelated_loop in ...) -// for (target in ...) -// for (j from i to min(i + X, N), step = 1) -// // pass j into affine maps -// -// There are two differences between mlir::tile and TileLoop: -// * TileLoop always puts the tiling logic "stepping" logic into AffineExprs. -// With that all index calculation is done in AffineExprs and easier to -// analyze in a single place. -// * TileLoop doesn't plan to use use max() and min() to resolve the issue when -// N % X != 0. max() and min() are not representable in AffineExprs. -// TODO(timshen): support the case where N % X != 0. -// -// TODO(timshen): consider the possibility to reuse mlir::tile's logic to -// achieve the same goal. -mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, - mlir::AffineForOp target) { - CHECK(IsSimpleLoop(loop)); - CHECK(IsSimpleLoop(target)); - { - llvm::SmallVector all_loops; - getPerfectlyNestedLoops(all_loops, loop); - CHECK(absl::c_linear_search(all_loops, target)); - } - - auto builder = target.getBodyBuilder(); - - auto inner_loop = - builder.create(builder.getUnknownLoc(), 0, size); - { - auto& inner_operations = inner_loop.getBody()->getOperations(); - auto& target_operations = target.getBody()->getOperations(); - - inner_operations.splice(inner_operations.begin(), target_operations, - target_operations.begin(), - std::prev(target_operations.end(), 2)); - - mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0); - CHECK_EQ(0, length.cast().getValue() % size); - SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder); - } - - for (auto& use : - llvm::make_early_inc_range(loop.getInductionVar().getUses())) { - mlir::Operation* owner = use.getOwner(); - BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); - unsigned new_dim = affine_map.operands.size(); - affine_map.operands.push_back(inner_loop.getInductionVar()); - std::vector replacements; - for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) { - if (affine_map.operands[i] == loop.getInductionVar()) { - replacements.push_back(builder.getAffineDimExpr(i) * size + - builder.getAffineDimExpr(new_dim)); - } else { - replacements.push_back(builder.getAffineDimExpr(i)); - } - } - affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols( - replacements, {}, affine_map.operands.size(), 0); - auto new_op = - CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner)); - owner->replaceAllUsesWith(new_op); - owner->erase(); - } - return inner_loop; -} - // Hoist operations out of `where`. [begin_op, end_op) must be the first // operations of their parent loop, and `where` must be an ancestor of that // parent loop. @@ -387,21 +225,6 @@ mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) { return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where); } -// Sinks a segment of perfectly nested loops to the bottom. It implements this -// by rotating the loop nest by rotate_amount. -void SinkPerfectlyNestedLoops(absl::Span loops, - int rotate_amount) { - CHECK_GE(rotate_amount, 0); - std::vector permutation(loops.size()); - std::iota(permutation.begin(), permutation.end(), unsigned(0)); - std::rotate(permutation.begin(), - permutation.begin() + loops.size() - rotate_amount, - permutation.end()); - mlir::interchangeLoops( - llvm::ArrayRef(loops.begin(), loops.end()), - permutation); -} - struct InitialMlirConvAnchors { std::vector cartesian_product_loops; std::vector reduction_loops; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc new file mode 100644 index 00000000000..045d06c9c86 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc @@ -0,0 +1,150 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h" + +#include "absl/algorithm/container.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace mlir_gpu { + +BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { + if (auto load = mlir::dyn_cast(op)) { + return {load.getAffineMap(), + std::vector(load.getMapOperands().begin(), + load.getMapOperands().end())}; + } else if (auto store = mlir::dyn_cast(op)) { + return {store.getAffineMap(), + std::vector(store.getMapOperands().begin(), + store.getMapOperands().end())}; + } else { + CHECK(false); + } +} + +mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, + BoundAffineMap new_affine, + mlir::OpBuilder builder) { + if (auto load = mlir::dyn_cast(op)) { + return builder.create( + builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map, + new_affine.operands); + } else if (auto store = mlir::dyn_cast(op)) { + return builder.create( + builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(), + new_affine.affine_map, new_affine.operands); + } else { + CHECK(false); + } +} + +bool IsSimpleLoop(mlir::AffineForOp loop) { + return loop.getLowerBoundMap().isSingleConstant() && + loop.getLowerBoundMap().getSingleConstantResult() == 0 && + loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 && + std::next(loop.region().begin()) == loop.region().end(); +} + +std::vector CreateNestedSimpleLoops( + absl::Span upper_bounds, mlir::OpBuilder builder) { + std::vector loops; + loops.reserve(upper_bounds.size()); + for (int64_t dim : upper_bounds) { + auto loop = + builder.create(builder.getUnknownLoc(), 0, dim); + loops.push_back(loop); + builder = loop.getBodyBuilder(); + } + return loops; +} + +void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, + mlir::OpBuilder builder) { + CHECK(IsSimpleLoop(loop)); + + loop.setUpperBoundMap(mlir::AffineMap::get( + loop.getUpperBoundMap().getNumDims(), + loop.getUpperBoundMap().getNumSymbols(), {new_bound})); +} + +mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, + mlir::AffineForOp target) { + CHECK(IsSimpleLoop(loop)); + CHECK(IsSimpleLoop(target)); + { + llvm::SmallVector all_loops; + getPerfectlyNestedLoops(all_loops, loop); + CHECK(absl::c_linear_search(all_loops, target)); + } + + auto builder = target.getBodyBuilder(); + + auto inner_loop = + builder.create(builder.getUnknownLoc(), 0, size); + { + auto& inner_operations = inner_loop.getBody()->getOperations(); + auto& target_operations = target.getBody()->getOperations(); + + inner_operations.splice(inner_operations.begin(), target_operations, + target_operations.begin(), + std::prev(target_operations.end(), 2)); + + mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0); + CHECK_EQ(0, length.cast().getValue() % size); + SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder); + } + + for (auto& use : + llvm::make_early_inc_range(loop.getInductionVar().getUses())) { + mlir::Operation* owner = use.getOwner(); + BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); + unsigned new_dim = affine_map.operands.size(); + affine_map.operands.push_back(inner_loop.getInductionVar()); + std::vector replacements; + for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) { + if (affine_map.operands[i] == loop.getInductionVar()) { + replacements.push_back(builder.getAffineDimExpr(i) * size + + builder.getAffineDimExpr(new_dim)); + } else { + replacements.push_back(builder.getAffineDimExpr(i)); + } + } + affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols( + replacements, {}, affine_map.operands.size(), 0); + auto new_op = + CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner)); + owner->replaceAllUsesWith(new_op); + owner->erase(); + } + return inner_loop; +} + +void SinkPerfectlyNestedLoops(llvm::MutableArrayRef loops, + int rotate_amount) { + CHECK_GE(rotate_amount, 0); + std::vector permutation(loops.size()); + std::iota(permutation.begin(), permutation.end(), unsigned(0)); + std::rotate(permutation.begin(), + permutation.begin() + loops.size() - rotate_amount, + permutation.end()); + mlir::permuteLoops(loops, permutation); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h new file mode 100644 index 00000000000..ce4955378c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ + +#include "absl/types/span.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace mlir_gpu { + +struct BoundAffineMap { + mlir::AffineMap affine_map; + std::vector operands; +}; + +BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op); +mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, + BoundAffineMap new_affine, + mlir::OpBuilder builder); + +bool IsSimpleLoop(mlir::AffineForOp loop); +std::vector CreateNestedSimpleLoops( + absl::Span upper_bounds, mlir::OpBuilder builder); +void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, + mlir::OpBuilder builder); + +// Tile a loop with trip count N by `size`. For now, N has to be a multiple of +// size, but later this constraint will be removed. +// +// The major loop (with trip count N / size) stays as-is, while the minor loop +// (with trip count `size`) will take over the body of `target`, and be placed +// as the new body of `target`. +// +// `target` has to be within the same "perfectly nested loop group" as `loop`. +// See the documentation for mlir::getPerfectlyNestedLoops. +// +// Example: +// Before tiling `loop` with tile size X: +// for (loop in N) +// for (unrelated_loop in ...) +// for (target in ...) +// // pass loop into affine maps +// After: +// for (loop in N / X) +// for (unrelated_loop in ...) +// for (target in ...) +// for (tiled_loop in X) +// // rewrite all affine exprs from loop to `loop * X + tiled_loop`. +// +// Design note: +// TileLoop is different from mlir::tile. At the moment, mlir::tile is not well +// documented about the exact tiling semantics, but the observed behavior is: +// for (i from 0 to N) +// for (unrelated_loop in ...) +// for (target in ...) +// // pass i into affine maps +// => +// for (i from 0 to N, step = X) +// for (unrelated_loop in ...) +// for (target in ...) +// for (j from i to min(i + X, N), step = 1) +// // pass j into affine maps +// +// There are two differences between mlir::tile and TileLoop: +// * TileLoop always puts the tiling logic "stepping" logic into AffineExprs. +// With that all index calculation is done in AffineExprs and easier to +// analyze in a single place. +// * TileLoop doesn't plan to use use max() and min() to resolve the issue when +// N % X != 0. max() and min() are not representable in AffineExprs. +// TODO(timshen): support the case where N % X != 0. +// +// TODO(timshen): consider the possibility to reuse mlir::tile's logic to +// achieve the same goal. +mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, + mlir::AffineForOp target); + +// Sinks a segment of perfectly nested loops to the bottom. It implements this +// by rotating the loop nest by rotate_amount. +void SinkPerfectlyNestedLoops(llvm::MutableArrayRef loops, + int rotate_amount); + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 3f6019a059a..4658aebd571 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -285,6 +286,26 @@ Status TransferManager::WriteRootTupleIndexTable( stream, elements, device_buffer.on_device_shape(), &device_memory); } +Status TransferManager::WriteRootTupleIndexTable( + se::Stream* stream, const ShapeTree& buffer_tree) { + TF_RET_CHECK(buffer_tree.shape().IsTuple()); + if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) { + return Status::OK(); + } + se::DeviceMemoryBase device_memory = + buffer_tree.element({}).AsDeviceMemoryBase(); + TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) == + device_memory.size()); + + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape()); + ++i) { + elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase()); + } + return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(), + &device_memory); +} + Status TransferManager::TransferBufferFromDevice( se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 40fda188fe3..e5fa8ebae53 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -210,6 +211,9 @@ class TransferManager { // rather than writing all subbuffers. This method is always asynchronous. Status WriteRootTupleIndexTable(se::Stream* stream, const ShapedBuffer& device_buffer); + Status WriteRootTupleIndexTable( + se::Stream* stream, + const ShapeTree& buffer_tree); // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index a19f17996be..cc483c310e8 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, end = {k, std::min(i * block_size, n)}; } - if (!left_side) { - std::swap(end[0], end[1]); - } - if (transpose_a) { + if (!left_side ^ transpose_a) { std::swap(start[0], start[1]); std::swap(end[0], end[1]); } @@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, } XlaOp x_update; - auto zero = Zero(builder, S32); - auto start_index = ConstantR0WithType(builder, S32, j * block_size); - std::vector update_starts = {start_index, zero}; if (left_side) { x_update = BatchDot(inv_block, transpose_a, remainder, false, precision); } else { x_update = BatchDot(remainder, false, inv_block, transpose_a, precision); - std::swap(update_starts[0], update_starts[1]); } if (i == 0) { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index cec954645cc..1b29da0660a 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" + +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" @@ -21,8 +23,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" @@ -1010,6 +1014,35 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } + // Do not simplify the loop away when there is a side-effectful op, + // otherwise the infeed op may not inherit the data dependency from + // the while loop. + // + // Example: while_body (param_a) { + // param_a = parameter(0) + // infeed2 = infeed() + // } + // + // infeed1 = ... + // while = while(infeed1), body=while_body // infeed2 has implicit + // dependency on infeed1. + // + // After simplification: + // + // infeed1 = ... + // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1 + // // can be scheduled after infeed2. + // + bool has_side_effects = absl::c_any_of( + while_op->called_computations(), [](const HloComputation* computation) { + return computation->HasSideEffect(); + }); + if (has_side_effects) { + VLOG(2) << "Not attempting to simplify while loop because it contains a " + "side-effecting node: " + << while_op->ToShortString(); + continue; + } TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); changed |= result; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index cff0fd458e5..b5f9d0ce9de 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -209,8 +209,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } -// We can simplify loops whose bodies contain infeed or other side-effecting -// instructions other than send/recv. +// We can't simplify loops whose bodies contain infeed or other side-effecting +// instructions. TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) { auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); HloComputation* computation = m->entry_computation(); @@ -220,8 +220,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) { auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); - EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // We don't simplify trip-count-1 loops whose *conditions* contain infeed or @@ -445,47 +444,6 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } -// Check that we can remove unused loop operands even if the loop contains a -// side-effecting instruction. -TEST_F(WhileLoopSimplifierTest, - RemoveUnusedLoopOperandsDespiteSideEffectingOps) { - const string hlo_string = R"( - HloModule RemoveUnusedOperands - body { - loop_var = (s32[]) parameter(0) - gte0 = s32[] get-tuple-element(loop_var), index=0 - token0 = token[] after-all() - unused = ((s32[], pred[]), token[]) infeed(token0) - ROOT tuple = (s32[]) tuple(gte0) - } - cond { - loop_var = (s32[]) parameter(0) - ROOT constant = pred[] constant(true) - } - ENTRY RemoveUnusedOperands { - x = s32[] parameter(0) - tuple.1 = (s32[]) tuple(s32[] x) - ROOT while = (s32[]) while((s32[]) tuple.1), - condition=cond, body=body - } - )"; - - auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); - - // The original while instruction is still left in the module as a dead - // instruction, find a while instruction with a different name as the new - // while instruction. - const auto& instrs = m->entry_computation()->instructions(); - HloInstruction* new_while_op = - *absl::c_find_if(instrs, [&](const HloInstruction* instr) { - return (instr->opcode() == HloOpcode::kWhile && - instr->name() != "while"); - }); - EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape())) - << new_while_op->shape().ToString(); -} - TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { const string hlo_string = R"( HloModule BodyHasNonTupleRoot diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index 464865506f7..3aacf065156 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -218,6 +218,88 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) { TestAllOps(); } +XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) { + // Test with equal elements. + TestTwoReplicasOneOperand( + "and", + /*input_value=*/LiteralUtil::CreateR1({true, false}), + /*expected_value=*/LiteralUtil::CreateR1({true, false})); + + // Test with {true, false}. + const char* hlo_module = R"( + HloModule test + + apply_op { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT apply_op = pred[] and(x, y) + } + + ENTRY test_computation { + id = u32[] replica-id() + c = u32[] constant(0) + p = pred[] compare(id, c), direction=EQ + p2 = pred[1] bitcast(p) + crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op + copy = pred[1] copy(crs) + ROOT out = pred[1] bitcast(copy) + } + )"; + + auto config = GetModuleConfigForTest(); + config.set_replica_count(2); + auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {}, + /*num_replicas=*/2, + /*use_threads=*/true)); + for (int replica_idx = 0; replica_idx < 2; replica_idx++) { + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({false}), + results[replica_idx])); + } +} + +XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) { + // Test with equal elements. + TestTwoReplicasOneOperand( + "or", + /*input_value=*/LiteralUtil::CreateR1({true, false}), + /*expected_value=*/LiteralUtil::CreateR1({true, false})); + + // Test with {true, false}. + const char* hlo_module = R"( + HloModule test + + apply_op { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT apply_op = pred[] or(x, y) + } + + ENTRY test_computation { + id = u32[] replica-id() + c = u32[] constant(0) + p = pred[] compare(id, c), direction=EQ + p2 = pred[1] bitcast(p) + crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op + copy = pred[1] copy(crs) + ROOT out = pred[1] bitcast(copy) + } + )"; + + auto config = GetModuleConfigForTest(); + config.set_replica_count(2); + auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {}, + /*num_replicas=*/2, + /*use_threads=*/true)); + for (int replica_idx = 0; replica_idx < 2; replica_idx++) { + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({true}), + results[replica_idx])); + } +} + // Tries all-to-all operations across all 2^kNumDevices - 1 combinations of // devices in sequence. XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 1a1dda80f18..64d586a9514 100755 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -312,6 +312,22 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(std::unique_ptr module, + bool run_hlo_passes) { + const auto fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + const auto change = hlo_verifier_->Run(module.get()); + if (!change.ok()) { + return ::testing::AssertionFailure() << change.status(); + } + + const auto output = + test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes); + return output.ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << output.status().error_message(); +} + ::testing::AssertionResult HloTestBase::RunAndCompare( string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index eebe26ecde5..0b1801ebe23 100755 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -203,6 +203,11 @@ class HloTestBase : public ::testing::Test { const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + // Executes an hlo module with fake inputs and checks that the execution is + // successful. + ::testing::AssertionResult Run(std::unique_ptr module, + bool run_hlo_passes) TF_MUST_USE_RESULT; + // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. diff --git a/tensorflow/compiler/xla/tests/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc index f2a95ab126a..f3358f65ce3 100644 --- a/tensorflow/compiler/xla/tests/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc @@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { Array2D avals(spec.m, spec.m); avals.FillRandom(1.0); for (int i = 0; i < spec.m; ++i) { - avals(i, i) += 10; + avals(i, i) += 30; } std::pair bdims = spec.left_side ? std::make_pair(spec.m, spec.n) @@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { } ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(3e-2, 3e-2)); } std::vector TriangularSolveTests() { std::vector specs; - for (int m : {5, 10}) { - for (int n : {5, 10}) { + for (int m : {5, 10, 150}) { + for (int n : {5, 10, 150}) { for (bool left_side : {false, true}) { for (bool lower : {false, true}) { for (TriangularSolveOptions::Transpose transpose_a : diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 6711779cd2b..1fbce96625b 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/strings/match.h" @@ -28,6 +29,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" @@ -301,6 +303,40 @@ absl::InlinedVector, 8> CommonFactors( return bounds; } +ConvertedDimensionNumbers ConvertDimensionNumbers( + absl::Span from_dimensions, absl::Span from_sizes, + absl::Span to_sizes) { + ConvertedDimensionNumbers dimensions; + auto common_factors = CommonFactors(from_sizes, to_sizes); + for (int64 i = 0; i < common_factors.size() - 1; ++i) { + bool any_present = false; + bool all_present = true; + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + const bool present = absl::c_linear_search(from_dimensions, d); + any_present |= present; + all_present &= present; + } + if (all_present) { + for (int64 d = common_factors[i].second; d < common_factors[i + 1].second; + ++d) { + dimensions.to_dimensions.push_back(d); + } + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + dimensions.transformed_from_dimensions.push_back(d); + } + } else if (any_present) { + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + if (absl::c_linear_search(from_dimensions, d)) { + dimensions.untransformed_from_dimensions.push_back(d); + } + } + } + } + return dimensions; +} string SanitizeFileName(string file_name) { for (char& c : file_name) { if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 3ef41249d24..44a5bf4ea33 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" @@ -506,6 +507,18 @@ int64 Product(absl::Span xs); absl::InlinedVector, 8> CommonFactors( absl::Span a, absl::Span b); +struct ConvertedDimensionNumbers { + DimensionVector transformed_from_dimensions; + DimensionVector untransformed_from_dimensions; + DimensionVector to_dimensions; +}; + +// Convert and unsorted list of dimensions from one shapes dimension sizes to +// another shapes dimensions sizes. +ConvertedDimensionNumbers ConvertDimensionNumbers( + absl::Span from_dimensions, absl::Span from_sizes, + absl::Span to_sizes); + // Removes illegal characters from filenames. string SanitizeFileName(string file_name); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 28fbcaa4c18..4fd816dae4e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -199,7 +199,6 @@ COMMON_PROTO_SRCS = [ "protobuf/tensor_bundle.proto", "protobuf/saver.proto", "protobuf/verifier_config.proto", - "protobuf/trace_events.proto", ] EXAMPLE_PROTO_SRCS = [ @@ -2551,6 +2550,8 @@ filegroup( "common_runtime/executor_factory.h", "common_runtime/function_optimization_registry.h", "common_runtime/graph_optimizer.h", + "common_runtime/graph_view.h", + "common_runtime/immutable_executor_state.h", "common_runtime/input_colocation_exemption_registry.h", "common_runtime/isolate_placer_inspection_required_ops_pass.h", "common_runtime/local_device.h", @@ -2614,7 +2615,9 @@ tf_cuda_library( "common_runtime/function_optimization_registry.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", + "common_runtime/graph_view.cc", "common_runtime/hierarchical_tree_broadcaster.cc", + "common_runtime/immutable_executor_state.cc", "common_runtime/input_colocation_exemption_registry.cc", "common_runtime/inspecting_placer.cc", "common_runtime/isolate_placer_inspection_required_ops_pass.cc", diff --git a/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt b/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt index f7068106627..06db7468a1c 100644 --- a/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LinSpace" + visibility: HIDDEN in_arg { name: "start" description: <requested_size; MarkFree(h); + // TraceMe needs to be added after MarkFree and before InsertFreeChunkIntoBin + // for correct memory stats. + AddTraceMe("MemoryDeallocation", ptr); // Consider coalescing it. if (timing_counter_) { @@ -614,8 +616,6 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) { if (VLOG_IS_ON(4)) { LOG(INFO) << "F: " << RenderOccupancy(); } - - AddTraceMe("MemoryDeallocation", -requested_bytes); } // Merges h1 and h2 when Chunk(h1)->next is h2 and Chunk(h2)->prev is c1. diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index a41ca5a1066..94506bb3b7e 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -116,8 +116,9 @@ class BFCAllocator : public Allocator { TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); // Add TraceMe (in memory allocation and deallocation) for memory stats - // profiling. The requested_bytes can be negative if it's a deallocation. - void AddTraceMe(absl::string_view traceme_name, int64 requested_bytes) + // profiling. The chunk_ptr is passed to get information such as address, + // chunk size and requested_size. + void AddTraceMe(absl::string_view traceme_name, const void* chunk_ptr) TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); // A ChunkHandle is an index into the chunks_ vector in BFCAllocator diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 029582c04fd..2f5fbd2353c 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1337,10 +1337,11 @@ Status DirectSession::CreateExecutors( device_mgr_.get(), options_.env, &options_.config, graph_def_version, func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first, /*parent=*/nullptr, custom_kernel_creator, session_metadata, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - })); + Rendezvous::Factory{ + [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + return Status::OK(); + }})); GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 49403c080f6..33221e51218 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -128,11 +128,11 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env, thread::ThreadPool* thread_pool, DistributedFunctionLibraryRuntime* cluster_flr, const CustomKernelCreator* custom_kernel_creator) { - Rendezvous::Factory rendezvous_factory = + Rendezvous::Factory rendezvous_factory{ [this](const int64 step_id, const DeviceMgr*, Rendezvous** r) { *r = CreateRendezvous(step_id); return Status::OK(); - }; + }}; if (lazy_copy_function_remote_inputs_) { pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime( device_mgr, env, config, graph_def_version, lib_def, optimizer_options, @@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster( if (rendezvous_ != nullptr) rendezvous_->Unref(); rendezvous_ = r; remote_eager_workers_ = std::move(remote_eager_workers); + pflr_->InitializeDeviceSet(); InitPrioritizedDeviceTypeList(); default_executor_.ClearError(); diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index 1a1459a9f1c..793513c5c5f 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -50,8 +50,8 @@ class CopyToDeviceNode : public EagerNode { Status Run() override { tensorflow::Tensor tensor; - MEMDEBUG_CACHE_OP(MEMDEBUG_CACHE_VAL ? MEMDEBUG_CACHE_VAL - : "eager::CopyToDeviceNode"); + auto op_annotation = ScopedMemoryDebugAnnotation( + pending_op_name ? pending_op_name : "eager::CopyToDeviceNode"); TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); if (!async_ && mirror_) { return dst_->AddLocalMirror(std::move(tensor), dstd_); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 94b85a190c1..81d0528c8a2 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -57,10 +57,16 @@ Status EagerOperation::Reset( cancellation_manager_ = nullptr; executor_ = executor ? executor : &ctx_.Executor(); remote_func_params_ = remote_func_params; -#ifdef TENSORFLOW_MEM_DEBUG op_name_ = op; -#endif - return SetDeviceName(raw_device_name, true); + if (raw_device_name != nullptr && strlen(raw_device_name) > 0) { + return SetDeviceName(raw_device_name); + } else { + raw_device_name_.clear(); + device_name_.clear(); + device_parsed_name_.Clear(); + device_ = kVariantDeviceNull; + return Status::OK(); + } } Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) { @@ -130,7 +136,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { return Status::OK(); } -Status EagerOperation::SetDeviceName(const char* device, const bool reset) { +Status EagerOperation::SetDeviceName(const char* device) { if (device != nullptr && strlen(device) > 0) { if (device != raw_device_name_) { if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) { @@ -152,11 +158,6 @@ Status EagerOperation::SetDeviceName(const char* device, const bool reset) { device_ = kVariantDeviceNull; } } - } else if (reset) { - raw_device_name_.clear(); - device_name_.clear(); - device_parsed_name_.Clear(); - device_ = kVariantDeviceNull; } return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 4b46fc5c709..1ba55ead83d 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -98,7 +98,7 @@ class EagerOperation { const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } - Status SetDeviceName(const char* device, const bool reset = false); + Status SetDeviceName(const char* device); // Indicates whether the op is assigned to a device that is local to the // current host. @@ -121,10 +121,9 @@ class EagerOperation { return remote_func_params_; } -#ifdef TENSORFLOW_MEM_DEBUG + // Op name recorded for memory debugging purpose. const char* op_name() const { return op_name_; } const char* op_name_ = nullptr; -#endif Status MaybeInferSingleInputAttrs(TensorHandle* handle); Status InferInputListAttrs(int num_inputs); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 3d4cf6ae8fc..a1696e5f2d1 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -373,7 +373,7 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, // running without an explicitly requested device. Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, int* num_retvals) { - MEMDEBUG_CACHE_OP(op->op_name()); + auto op_annotation = ScopedMemoryDebugAnnotation(op->op_name()); profiler::TraceMe activity( [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); }, profiler::TraceMeLevel::kInfo); @@ -583,11 +583,11 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, if (executor.Async()) { const DataTypeVector& output_dtypes = kernel->output_dtypes(); for (int i = 0; i < num_outputs; ++i) { - TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + retvals[i] = TensorHandle::CreateEmptyLocalHandle( /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)), /* op_device= */ kernel->device(), /* resource_device= */ kernel->OutputResourceDevice(i), - output_dtypes[i], &ctx, &retvals[i])); + output_dtypes[i], &ctx); } auto node = absl::make_unique( &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), @@ -773,18 +773,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, // remote device here. We just need to know that it is remote. If we need // to copy this tensor to this process, the remote end will know the // correct device of this handle. - Status status = TensorHandle::CreateUnshapedRemoteHandle( - id, i, remote_task, output_dtypes[i], op_device, &ctx, &retvals[i]); - if (!status.ok()) { - for (int j = 0; j < i; ++j) { - retvals[j]->PoisonRemote( - errors::Internal( - "Failed to construct unshaped remote tensor handle at index ", - i, " for op ", op->Name()), - op_device, ctx.GetContextViewId()); - } - return status; - } + retvals[i] = TensorHandle::CreateUnshapedRemoteHandle( + id, i, remote_task, output_dtypes[i], op_device, &ctx); } if (ctx.LazyCopyFunctionRemoteInputs()) { @@ -1056,12 +1046,11 @@ Status EagerKernelExecute( for (int i = 0; i < retvals.size(); ++i) { if (retvals[i] == nullptr) { - TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( + retvals[i] = TensorHandle::CreateLocalHandle( std::move(outputs[i]), /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)), /* op_device= */ kernel->device(), - /* resource_device= */ kernel->OutputResourceDevice(i), ctx, - &retvals[i])); + /* resource_device= */ kernel->OutputResourceDevice(i), ctx); } else { DCHECK_EQ(kernel->device(), retvals[i]->op_device()); DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)), @@ -1100,8 +1089,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, h->Ref(); *result = h; } else { - TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( - d, dstd, h->resource_device(), h->dtype, ctx, result)); + *result = TensorHandle::CreateEmptyLocalHandle( + d, dstd, h->resource_device(), h->dtype, ctx); } Status s; @@ -1169,9 +1158,9 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, h->Ref(); *result = h; } else { - TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + *result = TensorHandle::CreateEmptyLocalHandle( /* d= */ d, /* op_device= */ device, - /*resource_device=*/nullptr, h->dtype, ctx, result)); + /*resource_device=*/nullptr, h->dtype, ctx); } } else { if (mirror) { @@ -1194,8 +1183,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, h->Ref(); *result = h; } else { - TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle( - recv_op_id, 0, remote_task, h->dtype, device, ctx, result)); + *result = TensorHandle::CreateUnshapedRemoteHandle( + recv_op_id, 0, remote_task, h->dtype, device, ctx); } } diff --git a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc index 2051a23f14b..c073dc1fd88 100644 --- a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc @@ -31,12 +31,6 @@ void EagerProcessFunctionLibraryRuntime::RunRemoteDevice( FunctionLibraryRuntime::Handle local_handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { - if (!rets->empty()) { - done( - errors::Unimplemented("Remote outputs are not supported by " - "EagerClusterFunctionLibraryRuntime yet.")); - return; - } parent_->Run(opts, local_handle, args, rets, std::move(done)); } @@ -50,8 +44,8 @@ void EagerProcessFunctionLibraryRuntime::Run( std::move(done)); } auto* cleanup_items = new std::vector>; - done = - ApplyCleanUpToDoneCallback(cleanup_items, done, /*rendezvous=*/nullptr); + done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id, + /*rendezvous=*/nullptr); auto get_component_args = [&args](const ComponentFunctionData& comp_data, InternalArgs* comp_args) -> Status { diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index dc805d091bf..e0d2a54728a 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -106,40 +106,36 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector* result) { return GetResourceHandleInfoImpl(get_resource_info); } -Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t, - TensorHandle** h) { +TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) { // TODO(b/136608821): Move away from nullptr tensorflow::Tensor tensor = t; return CreateLocalHandle(std::move(tensor), /*d=*/nullptr, /*op_device=*/nullptr, - /*ctx=*/nullptr, h); + /*ctx=*/nullptr); } -Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, - Device* op_device, EagerContext* ctx, - TensorHandle** h) { - return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h); +TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, + EagerContext* ctx) { + return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx); } -Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, - Device* op_device, - Device* resource_device, - EagerContext* ctx, TensorHandle** h) { +TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, + Device* resource_device, + EagerContext* ctx) { if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { - *h = new TensorHandle(std::move(t), d, op_device, ctx); + return new TensorHandle(std::move(t), d, op_device, ctx); } else { - *h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx); + return new TensorHandle(std::move(t), d, op_device, resource_device, ctx); } - - return Status::OK(); } -Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d, - EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(std::move(t), d, ctx); - - return Status::OK(); +TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, + CustomDevice* d, + EagerContext* ctx) { + return new TensorHandle(std::move(t), d, ctx); } TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, @@ -190,13 +186,11 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, << " tensor: " << t.DeviceSafeDebugString(); } -Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device, - Device* resource_device, - DataType dtype, EagerContext* ctx, - TensorHandle** h) { - *h = new TensorHandle(d, op_device, resource_device, dtype, ctx); - - return Status::OK(); +TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device, + Device* resource_device, + DataType dtype, + EagerContext* ctx) { + return new TensorHandle(d, op_device, resource_device, dtype, ctx); } TensorHandle::TensorHandle(Device* d, Device* op_device, @@ -214,14 +208,10 @@ TensorHandle::TensorHandle(Device* d, Device* op_device, } #if !defined(IS_MOBILE_PLATFORM) -Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, - const string& remote_task, - DataType dtype, Device* d, - EagerContext* ctx, - TensorHandle** h) { - *h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx); - - return Status::OK(); +TensorHandle* TensorHandle::CreateUnshapedRemoteHandle( + int64 op_id, int32 output_num, const string& remote_task, DataType dtype, + Device* d, EagerContext* ctx) { + return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx); } TensorHandle::TensorHandle(int64 op_id, int32 output_num, @@ -239,13 +229,11 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num, << " device: " << VariantDeviceDebugString(device_); } -Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num, - DataType dtype, Device* d, - EagerContext* ctx, - TensorHandle** h) { - *h = new TensorHandle(op_id, output_num, dtype, d, ctx); - - return Status::OK(); +TensorHandle* TensorHandle::CreateLazyRemoteHandle(int64 op_id, + int32 output_num, + DataType dtype, Device* d, + EagerContext* ctx) { + return new TensorHandle(op_id, output_num, dtype, d, ctx); } TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype, diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 030976f32b8..a67345b1156 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -77,27 +77,27 @@ class TensorHandle : public core::RefCounted { public: // TensorHandle with no assigned device - static Status CreateLocalHandle(const tensorflow::Tensor& t, - TensorHandle** h); - static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d, - Device* op_device, EagerContext* ctx, - TensorHandle** h); - static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d, - Device* op_device, Device* resource_device, - EagerContext* ctx, TensorHandle** h); - static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d, - EagerContext* ctx, TensorHandle** h); - static Status CreateEmptyLocalHandle(Device* d, Device* op_device, - Device* resource_device, DataType dtype, - EagerContext* ctx, TensorHandle** h); + static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t); + static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, EagerContext* ctx); + static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, + Device* resource_device, + EagerContext* ctx); + static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, + CustomDevice* d, EagerContext* ctx); + static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device, + Device* resource_device, + DataType dtype, + EagerContext* ctx); #if !defined(IS_MOBILE_PLATFORM) - static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, - const string& remote_task, - DataType dtype, Device* d, - EagerContext* ctx, TensorHandle** h); - static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num, - DataType dtype, Device* d, - EagerContext* ctx, TensorHandle** h); + static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, + const string& remote_task, + DataType dtype, Device* d, + EagerContext* ctx); + static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num, + DataType dtype, Device* d, + EagerContext* ctx); #endif // IS_MOBILE_PLATFORM ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 6c62334281c..9485dfb4764 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -38,15 +38,10 @@ TEST(TensorHandle_ShapeTest, AsyncShape) { tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false, &device_mgr, false, nullptr, nullptr, nullptr); - TensorHandle* sync_th; - EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, - ctx, &sync_th) - .ok()); - TensorHandle* async_th; - EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr, - DataType::DT_UINT16, ctx, - &async_th) - .ok()); + TensorHandle* sync_th = + TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, ctx); + TensorHandle* async_th = TensorHandle::CreateEmptyLocalHandle( + nullptr, nullptr, nullptr, DataType::DT_UINT16, ctx); EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok()); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 946654a1605..43972589c17 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -16,17 +16,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" #include -#include #include -#include -#include #include #include "absl/memory/memory.h" -#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/costmodel_manager.h" #include "tensorflow/core/common_runtime/executor_factory.h" -#include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/common_runtime/graph_view.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" #include "tensorflow/core/common_runtime/pending_counts.h" #include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" @@ -49,15 +46,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/manual_constructor.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -78,10 +71,6 @@ namespace { // 1-D, 0 element tensor. static const Tensor* const kEmptyTensor = new Tensor; -bool IsInitializationOp(const Node* node) { - return node->op_def().allows_uninitialized_input(); -} - // Helper routines for collecting step stats. namespace nodestats { inline int64 NowInNsec() { return EnvTime::NowNanos(); } @@ -124,19 +113,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { } // namespace nodestats class ExecutorImpl; -class GraphView; - -struct EdgeInfo { - int dst_id; - int output_slot : 31; - // true if this is the last info for output_slot in the EdgeInfo list. - bool is_last : 1; - int input_slot; -}; - -struct ControlEdgeInfo { - int dst_id; -}; // Time the execution of kernels (in CPU cycles). Used to dynamically identify // inexpensive kernels which can be dispatched inline. @@ -148,222 +124,19 @@ struct KernelTimer { } }; -// Compact structure representing a graph node and its associated kernel. -// -// Each NodeItem is an element of exactly one GraphView. -struct NodeItem { - NodeItem() {} - - // The index of this node's item in its GraphView. - int node_id = -1; - - // Cached attributes of this node for fast lookup. - bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr - bool is_merge : 1; // True iff IsMerge(node) - bool is_enter : 1; // True iff IsEnter(node) - bool is_constant_enter : 1; // True iff IsEnter(node) and - // node->GetAttr("is_constant") == true. - bool is_exit : 1; // True iff IsExit(node) - bool is_control_trigger : 1; // True iff IsControlTrigger(node) - bool is_source : 1; // True iff IsSource(node) - // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) - bool is_enter_exit_or_next_iter : 1; - bool is_transfer_node : 1; // True iff IsTransferNode(node) - bool is_initialization_op : 1; // True iff IsInitializationOp(node) - bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) - bool is_next_iteration : 1; // True iff IsNextIteration(node) - bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") - bool - is_any_consumer_merge_or_control_trigger : 1; // True iff the destination - // of any output edge is a - // merge or control trigger - // node. - - // The kernel for this node. - OpKernel* kernel = nullptr; - - // If the kernel is a Const op, this containts points to the constant tensor. - const Tensor* const_tensor = nullptr; - - // Cached values of node->num_inputs() and node->num_outputs(), to - // avoid levels of indirection. - int num_inputs; - int num_outputs; - - // ExecutorImpl::tensors_[input_start] is the 1st positional input - // for this node. - int input_start = 0; - - // Number of output edges, excluding control edges. - int32 num_output_edges; - - // Number of output control edges. - int32 num_output_control_edges; - - // If non-null, contains an array of num_outputs bools, where the ith bool - // is true if and only if the ith output is consumed by another node. - std::unique_ptr outputs_required; - - gtl::MutableArraySlice mutable_output_edges() { - return gtl::MutableArraySlice(output_edge_base(), - num_output_edges); - } - - gtl::ArraySlice output_edges() const { - return gtl::ArraySlice(output_edge_base(), num_output_edges); - } - - gtl::ArraySlice output_control_edges() const { - return gtl::ArraySlice(output_control_edge_base(), - num_output_control_edges); - } - - DataType input_type(int i) const { - DCHECK_LT(i, num_inputs); - return static_cast(input_type_base()[i]); - } - DataType output_type(int i) const { - DCHECK_LT(i, num_outputs); - return static_cast(output_type_base()[i]); - } - - // Return array of per-output allocator attributes. - const AllocatorAttributes* output_attrs() const { return output_attr_base(); } - - // Return array of expected input index from which each output should - // be forwarded: - // kNeverForward (-2) for DO NOT FORWARD (must allocate). - // kNoReservation (-1) for no expected forwarding. - // 0... for forward from that input. - const int* forward_from() const { return forward_from_base(); } - - string DebugString() const { - string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id); - if (is_source) { - strings::StrAppend(&ret, " source}"); - } else { - strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}"); - } - return ret; - } - - private: - friend class GraphView; - - // Variable length section starts immediately after *this - // (uint8 is enough for DataType). - // EdgeInfo out_edges[num_out_edges]; - // AllocatorAttributes output_attr[num_outputs]; - // int forward_from[num_outputs]; - // uint8 input_type[num_inputs]; - // uint8 output_type[num_outputs]; - - // Return pointer to variable length section. - char* var() const { - return const_cast(reinterpret_cast(this) + - sizeof(NodeItem)); - } - - EdgeInfo* output_edge_base() const { - return reinterpret_cast(var()); - } - - ControlEdgeInfo* output_control_edge_base() const { - return reinterpret_cast(var() + sizeof(EdgeInfo) * - num_output_edges); - } - - AllocatorAttributes* output_attr_base() const { - return reinterpret_cast( - var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(ControlEdgeInfo) * num_output_control_edges); - } - int* forward_from_base() const { - return reinterpret_cast(var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(ControlEdgeInfo) * - num_output_control_edges + - sizeof(AllocatorAttributes) * num_outputs); - } - uint8* input_type_base() const { - return reinterpret_cast( - var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(ControlEdgeInfo) * num_output_control_edges + - sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); - } - uint8* output_type_base() const { - return reinterpret_cast( - var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(ControlEdgeInfo) * num_output_control_edges + - sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + - sizeof(uint8) * num_inputs); - } - - TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); -}; - typedef gtl::InlinedVector TensorValueVec; typedef gtl::InlinedVector AllocatorAttributeVec; -// Immutable view of a Graph organized for efficient execution. -class GraphView { - public: - GraphView() : space_(nullptr) {} - ~GraphView(); - - Status Initialize(const Graph* g); - Status SetAllocAttrs(const Graph* g, const Device* device); - void SetScopedAllocatorAttrs(const std::vector& sa_nodes); - - NodeItem* node(int32 id) const { - DCHECK_GE(id, 0); - DCHECK_LT(id, num_nodes_); - uint32 offset = node_offsets_[id]; - return ((offset == kuint32max) - ? nullptr - : reinterpret_cast(space_ + node_offsets_[id])); - } - - int32 num_nodes() const { return num_nodes_; } - - private: - char* InitializeNode(char* ptr, const Node* n); - size_t NodeItemBytes(const Node* n); - - int32 num_nodes_ = 0; - uint32* node_offsets_ = nullptr; // array of size "num_nodes_" - // node_offsets_[id] holds the byte offset for node w/ "id" in space_ - - char* space_; // NodeItem objects are allocated here - - TF_DISALLOW_COPY_AND_ASSIGN(GraphView); -}; - class ExecutorImpl : public Executor { public: - explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() { - CHECK(p.create_kernel != nullptr); - CHECK(p.delete_kernel != nullptr); + explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {} + + Status Initialize(const Graph& graph) { + TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph)); + kernel_stats_.Initialize(immutable_state_.graph_view()); + return Status::OK(); } - ~ExecutorImpl() override { - for (int32 i = 0; i < gview_.num_nodes(); i++) { - NodeItem* item = gview_.node(i); - if (item != nullptr) { - params_.delete_kernel(item->kernel); - } - } - for (auto fiter : frame_info_) { - delete fiter.second; - } - } - - Status Initialize(const Graph& graph); - - // Process all Nodes in the current graph, attempting to infer the - // memory allocation attributes to be used wherever they may allocate - // a tensor buffer. - Status SetAllocAttrs(); - void RunAsync(const Args& args, DoneCallback done) override; private: @@ -432,636 +205,20 @@ class ExecutorImpl : public Executor { std::unique_ptr cost_estimates_; }; - struct ControlFlowInfo { - gtl::FlatSet unique_frame_names; - std::vector frame_names; - }; - - struct FrameInfo { - FrameInfo() - : input_count(0), - total_inputs(0), - pending_counts(nullptr), - nodes(nullptr) {} - - // The total number of inputs to a frame. - int input_count; - - // The total number of input tensors of a frame. - // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. - int total_inputs; - - // Used to determine the next place to allocate space in the - // pending_counts data structure we'll eventually construct - PendingCounts::Layout pending_counts_layout; - - // Each frame has its own PendingCounts only for the nodes in the frame. - PendingCounts* pending_counts; // Owned - - // The nodes in a frame. Used only for debugging. - std::vector* nodes; // Owned - - ~FrameInfo() { - delete pending_counts; - delete nodes; - } - }; - - static Status BuildControlFlowInfo(const Graph* graph, - ControlFlowInfo* cf_info); - void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); - - FrameInfo* EnsureFrameInfo(const string& fname) { - auto slot = &frame_info_[fname]; - if (*slot == nullptr) { - *slot = new FrameInfo; - } - return *slot; - } - - // Owned. - LocalExecutorParams params_; - GraphView gview_; - std::vector pending_ids_; - mutable KernelStats kernel_stats_; - - // Root nodes (with no in edges) that should form the initial ready queue - std::vector root_nodes_; - - // Mapping from frame name to static information about the frame. - // TODO(yuanbyu): We could cache it along with the graph so to avoid - // the overhead of constructing it for each executor instance. - gtl::FlatMap frame_info_; - - // Shallow copies of the constant tensors used in the graph. - std::vector const_tensors_; + ImmutableExecutorState immutable_state_; + KernelStats kernel_stats_; TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl); }; -// Infer memory allocation attributes of a node n's output, -// based on its use node dst. Note that dst might not be directly -// connected to n by a single edge, but might be a downstream -// consumer of n's output by reference. *attr is updated with any -// necessary attributes. -Status InferAllocAttr(const Node* n, const Node* dst, - const DeviceNameUtils::ParsedName& local_dev_name, - AllocatorAttributes* attr); - -GraphView::~GraphView() { - static_assert(std::is_trivially_destructible::value, - "Update code if AllocatorAttributes gains a destructor"); - static_assert(std::is_trivially_destructible::value, - "Update code if EdgeInfo gains a destructor"); - for (int i = 0; i < num_nodes_; i++) { - NodeItem* n = node(i); - if (n != nullptr) { - n->NodeItem::~NodeItem(); - // Memory for "n" itself is held in space_ & gets cleaned up below - } - } - delete[] node_offsets_; - delete[] space_; -} - -typedef std::tuple OutputAndControlEdges; - -static OutputAndControlEdges CountOutputEdges(const Node* n) { - DCHECK_LE(n->out_edges().size(), kint32max); - int32 num_output_edges = 0; - int32 num_output_control_edges = 0; - for (auto e : n->out_edges()) { - if (IsSink(e->dst())) continue; - if (e->IsControlEdge()) { - ++num_output_control_edges; - } else { - ++num_output_edges; - } - } - return OutputAndControlEdges(num_output_edges, num_output_control_edges); -} - -size_t GraphView::NodeItemBytes(const Node* n) { - int32 num_output_edges; - int32 num_output_control_edges; - std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n); - const int num_inputs = n->num_inputs(); - const int num_outputs = n->num_outputs(); - - // Compute number of bytes needed for NodeItem and variable length data. - // We do not subtract sizeof(var) since num_inputs/num_outputs might - // both be zero. - const size_t raw_bytes = - sizeof(NodeItem) // Fixed - + num_output_edges * sizeof(EdgeInfo) // output_edges[...] - + num_output_control_edges * // - sizeof(ControlEdgeInfo) // output_control_edges[...] - + num_outputs * sizeof(AllocatorAttributes) // output_attr[...] - + num_outputs * sizeof(int) // forward_from[num_outputs] - + num_inputs * sizeof(uint8) // input_type[num_inputs] - + num_outputs * sizeof(uint8); // output_type[num_outputs] - static constexpr size_t kItemAlignment = sizeof(NodeItem*); - static_assert(kItemAlignment % alignof(NodeItem) == 0, - "NodeItem must be aligned with kItemAlignment"); - static_assert(kItemAlignment % alignof(EdgeInfo) == 0, - "EdgeInfo must be aligned with kItemAlignment"); - static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0, - "ControlEdgeInfo must be aligned with kItemAlignment"); - static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0, - "AllocatorAttributes must be aligned with kItemAlignment"); - static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0, - "NodeItem must be aligned with EdgeInfo"); - static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0, - "NodeItem must be aligned with AllocatorAttributes"); - static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0, - "EdgeInfo must be aligned with AllocatorAttributes"); - const size_t bytes = - ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment; - return bytes; -} - -char* GraphView::InitializeNode(char* ptr, const Node* n) { - const int id = n->id(); - CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor - - const size_t bytes = NodeItemBytes(n); - constexpr size_t kItemAlignment = sizeof(NodeItem*); - CHECK_EQ(reinterpret_cast(ptr) % kItemAlignment, 0); - NodeItem* item = reinterpret_cast(ptr); - - // We store a 32-bit offset relative to the beginning of space_, so that we - // only need an array of 32-bit values to map from node id to the NodeItem*, - // (versus 64 bits on most machines if we just stored an array of NodeItem* - // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing - // values as "int" vs "size_t" in CHECK_LE. - CHECK_LE(static_cast(ptr - space_), kuint32max); - const uint32 offset = static_cast(ptr - space_); - node_offsets_[id] = offset; - ptr += bytes; - - int32 num_output_edges; - int32 num_output_control_edges; - std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n); - const int num_inputs = n->num_inputs(); - const int num_outputs = n->num_outputs(); - - new (item) NodeItem(); - item->num_inputs = num_inputs; - item->num_outputs = num_outputs; - item->num_output_edges = num_output_edges; - item->num_output_control_edges = num_output_control_edges; - - // Fill output edges. - // Keep track of the last EdgeInfo in the EdgeInfo array that references - // a given output slot. For all but the last, we need to do a copy of the - // Tensor when propagating results downstream in the graph, but for the - // last one, we can just do a move of the Tensor object to propagate it. - gtl::InlinedVector last_indices(num_outputs, nullptr); - EdgeInfo* dst_edge = item->output_edge_base(); - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) continue; - dst_edge->dst_id = e->dst()->id(); - CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits - dst_edge->output_slot = e->src_output(); - dst_edge->is_last = false; - const int output_slot = dst_edge->output_slot; - if (output_slot >= 0) { - last_indices[output_slot] = dst_edge; - } - // NOTE: The `input_slot` will be rewritten to the frame-wide offset later - // in `ExecutorImpl::Initialize()`. - dst_edge->input_slot = e->dst_input(); - dst_edge++; - } - for (EdgeInfo* edge_info : last_indices) { - if (edge_info != nullptr) { - edge_info->is_last = true; - } - } - ControlEdgeInfo* dst_control_edge = item->output_control_edge_base(); - for (auto e : n->out_edges()) { - if (!e->IsControlEdge() || IsSink(e->dst())) continue; - dst_control_edge->dst_id = e->dst()->id(); - dst_control_edge++; - } - - AllocatorAttributes* output_attrs = item->output_attr_base(); - for (int i = 0; i < num_outputs; i++) { - new (&output_attrs[i]) AllocatorAttributes(); - } - - DCHECK_LT(DataType_MAX, 255); // Must fit in uint8 - uint8* input_types = item->input_type_base(); - for (int i = 0; i < num_inputs; i++) { - input_types[i] = static_cast(n->input_type(i)); - DCHECK_EQ(item->input_type(i), n->input_type(i)); - } - - // Check ScopedAllocatorAttrs and forward_from. Also assign output_types. - { - std::vector forward_input; - Status fwd_status = - GetNodeAttr(n->attrs(), "_forward_input", &forward_input); - std::vector scoped_allocator_attrs; - Status sa_status = - GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs); - - int* forward_from = item->forward_from_base(); - uint8* output_types = item->output_type_base(); - for (int i = 0; i < num_outputs; ++i) { - output_types[i] = static_cast(n->output_type(i)); - DCHECK_EQ(item->output_type(i), n->output_type(i)); - - forward_from[i] = OpKernelContext::Params::kNoReservation; - if (sa_status.ok()) { - for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) { - if (scoped_allocator_attrs[j] == i) { - // This output slot must be explicitly allocated from a - // ScopedAllocator. - forward_from[i] = OpKernelContext::Params::kNeverForward; - DCHECK_EQ(output_attrs[i].scope_id, 0); - output_attrs[i].scope_id = scoped_allocator_attrs[j + 1]; - } - } - } - if (fwd_status.ok() && - forward_from[i] == OpKernelContext::Params::kNoReservation) { - DCHECK_EQ(forward_input.size() % 2, 0); - for (int j = 0; j < forward_input.size(); j += 2) { - if (forward_input[j + 1] == i) { - DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation); - forward_from[i] = forward_input[j]; - break; - } - } - } - } - } - - return ptr; -} - -Status GraphView::Initialize(const Graph* g) { - CHECK(node_offsets_ == nullptr); - const int num_nodes = g->num_node_ids(); - num_nodes_ = num_nodes; - size_t total_bytes = 0; - for (const Node* n : g->nodes()) { - if (n->out_edges().size() > kint32max) { - return errors::InvalidArgument( - "The executor cannot handle nodes with more than ", kint32max, - " output edges. Node ", n->name(), " had ", n->out_edges().size(), - " output edges."); - } - total_bytes += NodeItemBytes(n); - } - - node_offsets_ = new uint32[num_nodes]; - for (int i = 0; i < num_nodes; i++) { - node_offsets_[i] = kuint32max; - } - - space_ = new char[total_bytes]; // NodeItem objects are allocated here - char* ptr = space_; - for (const Node* n : g->nodes()) { - ptr = InitializeNode(ptr, n); - } - CHECK_EQ(ptr, space_ + total_bytes); - return Status::OK(); -} - -void GetMaxPendingCounts(const Node* n, size_t* max_pending, - size_t* max_dead_count) { - const size_t num_in_edges = n->in_edges().size(); - size_t initial_count; - if (IsMerge(n)) { - // merge waits all control inputs so we initialize the pending - // count to be the number of control edges. - int32 num_control_edges = 0; - for (const Edge* edge : n->in_edges()) { - if (edge->IsControlEdge()) { - num_control_edges++; - } - } - // Use bit 0 to indicate if we are waiting for a ready live data input. - initial_count = 1 + (num_control_edges << 1); - } else { - initial_count = num_in_edges; - } - - *max_pending = initial_count; - *max_dead_count = num_in_edges; -} - -Status ExecutorImpl::Initialize(const Graph& graph) { - TF_RETURN_IF_ERROR(gview_.Initialize(&graph)); - - // Build the information about frames in this subgraph. - ControlFlowInfo cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info)); - - for (auto& it : cf_info.unique_frame_names) { - EnsureFrameInfo(it)->nodes = new std::vector; - } - - pending_ids_.resize(gview_.num_nodes()); - - // Preprocess every node in the graph to create an instance of op - // kernel for each node. - for (const Node* n : graph.nodes()) { - if (IsSink(n)) continue; - const int id = n->id(); - const string& frame_name = cf_info.frame_names[id]; - FrameInfo* frame_info = EnsureFrameInfo(frame_name); - - NodeItem* item = gview_.node(id); - item->node_id = id; - - item->input_start = frame_info->total_inputs; - frame_info->total_inputs += n->num_inputs(); - - Status s = params_.create_kernel(n->properties(), &item->kernel); - if (!s.ok()) { - item->kernel = nullptr; - s = AttachDef(s, *n); - return s; - } - CHECK(item->kernel); - item->kernel_is_async = (item->kernel->AsAsync() != nullptr); - item->is_merge = IsMerge(n); - item->is_any_consumer_merge_or_control_trigger = false; - for (const Node* consumer : n->out_nodes()) { - if (IsMerge(consumer) || IsControlTrigger(consumer)) { - item->is_any_consumer_merge_or_control_trigger = true; - break; - } - } - const Tensor* const_tensor = item->kernel->const_tensor(); - if (const_tensor) { - // Hold onto a shallow copy of the constant tensor in `*this` so that the - // reference count does not drop to 1. This prevents the constant tensor - // from being forwarded, and its buffer reused. - const_tensors_.emplace_back(*const_tensor); - } - item->const_tensor = const_tensor; - item->is_noop = (item->kernel->type_string_view() == "NoOp"); - item->is_enter = IsEnter(n); - if (item->is_enter) { - bool is_constant_enter; - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter)); - item->is_constant_enter = is_constant_enter; - } else { - item->is_constant_enter = false; - } - item->is_exit = IsExit(n); - item->is_control_trigger = IsControlTrigger(n); - item->is_source = IsSource(n); - item->is_enter_exit_or_next_iter = - (IsEnter(n) || IsExit(n) || IsNextIteration(n)); - item->is_transfer_node = IsTransferNode(n); - item->is_initialization_op = IsInitializationOp(n); - item->is_recv_or_switch = IsRecv(n) || IsSwitch(n); - item->is_next_iteration = IsNextIteration(n); - - // Compute the maximum values we'll store for this node in the - // pending counts data structure, and allocate a handle in - // that frame's pending counts data structure that has enough - // space to store these maximal count values. - size_t max_pending, max_dead; - GetMaxPendingCounts(n, &max_pending, &max_dead); - pending_ids_[id] = - frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead); - - // See if this node is a root node, and if so, add item to root_nodes_. - if (n->in_edges().empty()) { - root_nodes_.push_back(item); - } - - // Initialize static information about the frames in the graph. - frame_info->nodes->push_back(item); - if (item->is_enter) { - string enter_name; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); - EnsureFrameInfo(enter_name)->input_count++; - } - - // Record information about whether each output of the op is used. - std::unique_ptr outputs_required(new bool[n->num_outputs()]); - std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false); - int32 unused_outputs = n->num_outputs(); - for (const Edge* e : n->out_edges()) { - if (IsSink(e->dst())) continue; - if (e->src_output() >= 0) { - if (!outputs_required[e->src_output()]) { - --unused_outputs; - outputs_required[e->src_output()] = true; - } - } - } - if (unused_outputs > 0) { - for (int i = 0; i < n->num_outputs(); ++i) { - if (!outputs_required[i]) { - metrics::RecordUnusedOutput(n->type_string()); - } - } - item->outputs_required = std::move(outputs_required); - } - } - - // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input - // location. - for (const Node* n : graph.nodes()) { - if (IsSink(n)) continue; - const int id = n->id(); - NodeItem* item = gview_.node(id); - - for (EdgeInfo& e : item->mutable_output_edges()) { - const int dst_id = e.dst_id; - NodeItem* dst_item = gview_.node(dst_id); - e.input_slot += dst_item->input_start; - } - } - - // Initialize PendingCounts only after pending_ids_[node.id] is initialized - // for all nodes. - InitializePending(&graph, cf_info); - kernel_stats_.Initialize(gview_); - return gview_.SetAllocAttrs(&graph, params_.device); -} - -// If a Node has been marked to use a ScopedAllocator x for output i, then -// sc_attr will contain the subsequence (i, x) at an even offset. This function -// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we -// only allow one ScopedAllocator use per Node. -bool ExtractScopedAllocatorAttr(const std::vector& sc_attr, - int output_index, - AllocatorAttributes* alloc_attr) { - DCHECK_LE(2, sc_attr.size()); - for (int i = 0; i < sc_attr.size(); i += 2) { - if (sc_attr[i] == output_index) { - CHECK_EQ(alloc_attr->scope_id, 0); - alloc_attr->scope_id = sc_attr[i + 1]; - return true; - } - } - return false; -} - -void GraphView::SetScopedAllocatorAttrs( - const std::vector& sa_nodes) { - for (const Node* sa : sa_nodes) { - NodeItem* sa_item = node(sa->id()); - AllocatorAttributes* sa_attrs = sa_item->output_attr_base(); - // Control edges out of the ScopedAllocator should be use instances, but may - // include a few other nodes. - for (const auto& e : sa->out_edges()) { - if (IsSink(e->dst()) || !e->IsControlEdge()) { - continue; - } - Node* use_node = e->dst(); - NodeItem* item = node(use_node->id()); - AllocatorAttributes* use_attrs = item->output_attr_base(); - std::vector scoped_allocator_attrs; - Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator", - &scoped_allocator_attrs); - if (!s.ok()) { - VLOG(2) << "Failed to find expected ScopedAllocator attr on " - << use_node->name(); - continue; - } - // There can be more than one output using ScopedAllocation, but this - // analysis assumes they use the same ScopedAllocator. - for (const auto& e : use_node->out_edges()) { - if (IsSink(e->dst()) || !e->IsControlEdge()) { - AllocatorAttributes attr; - if (ExtractScopedAllocatorAttr(scoped_allocator_attrs, - e->src_output(), &attr)) { - // Set the scope_id on this use instance node. - (use_attrs + e->src_output())->Merge(attr); - // Propagate the other attributes of this node back to the SA node. - attr = *(use_attrs + e->src_output()); - attr.scope_id = 0; - sa_attrs->Merge(attr); - } - } - } - } - } -} - -Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { - Status s; - DeviceNameUtils::ParsedName local_dev_name = device->parsed_name(); - - std::vector scoped_allocator_instances; - for (const Node* n : g->nodes()) { - NodeItem* item = node(n->id()); - AllocatorAttributes* attrs = item->output_attr_base(); - if (IsScopedAllocator(n)) { - scoped_allocator_instances.push_back(n); - } - - // Examine the out edges of each node looking for special use - // cases that may affect memory allocation attributes. - for (const auto& e : n->out_edges()) { - if (!e->IsControlEdge()) { - AllocatorAttributes attr; - s = InferAllocAttr(n, e->dst(), local_dev_name, &attr); - if (!s.ok()) return s; - if (attr.value != 0 || attr.scope_id != 0) { - attrs[e->src_output()].Merge(attr); - } - } - } - - for (int out = 0; out < n->num_outputs(); out++) { - const OpKernel* op_kernel = item->kernel; - DCHECK_LT(out, op_kernel->output_memory_types().size()); - bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; - if (on_host) { - AllocatorAttributes h; - h.set_on_host(on_host); - attrs[out].Merge(h); - } - } - } - SetScopedAllocatorAttrs(scoped_allocator_instances); - return s; -} - -Status InferAllocAttr(const Node* n, const Node* dst, - const DeviceNameUtils::ParsedName& local_dev_name, - AllocatorAttributes* attr) { - Status s; - // Note that it's possible for *n to be a Recv and *dst to be a Send, - // so these two cases are not mutually exclusive. - if (IsRecv(n)) { - string src_name; - s = GetNodeAttr(n->attrs(), "send_device", &src_name); - if (!s.ok()) return s; - DeviceNameUtils::ParsedName parsed_src_name; - if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) { - s = errors::Internal("Bad send_device attr '", src_name, "' in node ", - n->name()); - return s; - } - if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) { - // Value is going to be the sink of an RPC. - attr->set_nic_compatible(true); - VLOG(2) << "node " << n->name() << " is the sink of an RPC in"; - } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) && - parsed_src_name.type != "CPU") { - // Value is going to be the sink of a local DMA from GPU to CPU (or - // other types of accelerators). - attr->set_gpu_compatible(true); - VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy"; - } else { - VLOG(2) << "default alloc case local type " << local_dev_name.type - << " remote type " << parsed_src_name.type; - } - } - if (IsSend(dst)) { - string dst_name; - s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name); - if (!s.ok()) return s; - DeviceNameUtils::ParsedName parsed_dst_name; - if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) { - s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ", - n->name()); - return s; - } - if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) { - // Value is going to be the source of an RPC. - attr->set_nic_compatible(true); - VLOG(2) << "node " << n->name() << " is the source of an RPC out"; - } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) && - parsed_dst_name.type != "CPU") { - // Value is going to be the source of a local DMA from CPU to GPU (or - // other types of accelerators). - // Note that this does not cover the case where the allocation of the - // output tensor is not generated by the src: n. - attr->set_gpu_compatible(true); - VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy"; - } else { - VLOG(2) << "default alloc case local type " << local_dev_name.type - << " remote type " << parsed_dst_name.type; - } - } - if (n->IsCollective()) { - // We'll make the sweeping assumption that any collective op is going - // to be involved in network i/o. - attr->set_nic_compatible(true); - } - return s; -} - // The state associated with one invocation of ExecutorImpl::Run. // ExecutorState dispatches nodes when they become ready and keeps // track of how many predecessors of a node have not done (pending_). class ExecutorState { public: - ExecutorState(const Executor::Args& args, ExecutorImpl* impl); + ExecutorState(const Executor::Args& args, + const ImmutableExecutorState& immutable_state_, + ExecutorImpl::KernelStats* kernel_stats_); ~ExecutorState(); void RunAsync(Executor::DoneCallback done); @@ -1194,8 +351,8 @@ class ExecutorState { // The state of an iteration. // One copy per iteration. For iteration k, i-th node's j-th input is in - // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either - // a tensor pointer (pass-by-reference) or a tensor (pass-by-value). + // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is + // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). // // NOTE: No need to protect input_tensors[i] by any locks because it // is resized once. Each element of tensors_ is written once by the @@ -1239,8 +396,9 @@ class ExecutorState { }; struct FrameState { - explicit FrameState(const ExecutorImpl* impl, int parallel_iters) - : executor(impl), + explicit FrameState(const ImmutableExecutorState& immutable_state, + int parallel_iters) + : immutable_state(immutable_state), max_parallel_iterations(parallel_iters), num_outstanding_iterations(1), iterations(parallel_iters + 1), @@ -1274,8 +432,8 @@ class ExecutorState { // This frame state is mostly initialized lazily on demand so we // don't introduce unnecessary overhead. - // The executor the frame is in. - const ExecutorImpl* executor = nullptr; + // The immutable state of the executor the frame is in. + const ImmutableExecutorState& immutable_state; // The name of this frame, which is the concatenation of its parent // frame name, the iteration of the parent frame when this frame was @@ -1341,13 +499,13 @@ class ExecutorState { mutex mu; void InitializeFrameInfo(const string& enter_name) { - auto it_frame_info = executor->frame_info_.find(enter_name); - DCHECK(it_frame_info != executor->frame_info_.end()); - ExecutorImpl::FrameInfo* finfo = it_frame_info->second; - pending_counts = finfo->pending_counts; + const ImmutableExecutorState::FrameInfo* finfo = + immutable_state.get_frame_info(enter_name); + DCHECK_NE(finfo, nullptr); + pending_counts = finfo->pending_counts.get(); total_input_tensors = finfo->total_inputs; num_pending_inputs = finfo->input_count; - nodes = finfo->nodes; + nodes = finfo->nodes.get(); } inline IterationState* GetIteration(int64 iter) @@ -1461,9 +619,11 @@ class ExecutorState { // A tagged node: . struct TaggedNode { const NodeItem* node_item; - FrameState* input_frame = nullptr; - int64 input_iter = -1; - bool is_dead = false; + FrameState* input_frame; // = nullptr; + int64 input_iter; // = -1; + bool is_dead; // = false; + + TaggedNode() {} TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter, bool dead) @@ -1533,7 +693,8 @@ class ExecutorState { // instead of a pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; CallFrameInterface* call_frame_; - const ExecutorImpl* impl_; + const ImmutableExecutorState& immutable_state_; + ExecutorImpl::KernelStats* const kernel_stats_; CancellationManager* cancellation_manager_; // If not null, use this device to schedule intra-op operation std::unique_ptr user_device_; @@ -1664,7 +825,9 @@ class ExecutorState { } }; -ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) +ExecutorState::ExecutorState(const Executor::Args& args, + const ImmutableExecutorState& immutable_state, + ExecutorImpl::KernelStats* kernel_stats) : vlog_(VLOG_IS_ON(1)), log_memory_(LogMemory::IsEnabled()), step_id_(args.step_id), @@ -1672,7 +835,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) collective_executor_(args.collective_executor), session_state_(args.session_state), session_handle_(args.session_handle), - session_metadata_(impl->params_.session_metadata), + session_metadata_(immutable_state.params().session_metadata), tensor_store_(args.tensor_store), step_container_(args.step_container), stats_collector_(args.stats_collector), @@ -1681,14 +844,15 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) context_(ContextKind::kThread), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), - impl_(impl), + immutable_state_(immutable_state), + kernel_stats_(kernel_stats), cancellation_manager_(args.cancellation_manager), runner_(args.runner), sync_on_finish_(args.sync_on_finish), run_all_kernels_inline_(args.run_all_kernels_inline), num_outstanding_ops_(0) { if (args.user_intra_op_threadpool != nullptr) { - Device* device = impl_->params_.device; + Device* device = immutable_state_.params().device; user_device_ = RenamedDevice::NewRenamedDevice( device->name(), device, false, false, args.user_intra_op_threadpool); } @@ -1696,7 +860,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) // We start the entire execution in iteration 0 of the root frame // so let us create the root frame and the state for iteration 0. // We assume root_frame_->frame_name.empty(). - root_frame_ = new FrameState(impl_, 1); + root_frame_ = new FrameState(immutable_state_, 1); root_frame_->frame_id = 0; // must be 0 root_frame_->InitializeFrameInfo(root_frame_->frame_name); @@ -1718,94 +882,11 @@ ExecutorState::~ExecutorState() { delete slice_reader_cache_; } -Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, - ControlFlowInfo* cf_info) { - const int num_nodes = g->num_node_ids(); - cf_info->frame_names.resize(num_nodes); - std::vector parent_nodes; - parent_nodes.resize(num_nodes); - std::vector visited; - visited.resize(num_nodes); - - string frame_name; - std::deque ready; - - // Initialize with the root nodes. - for (Node* n : g->nodes()) { - if (n->in_edges().empty()) { - visited[n->id()] = true; - cf_info->unique_frame_names.insert(frame_name); - ready.push_back(n); - } - } - - while (!ready.empty()) { - Node* curr_node = ready.front(); - int curr_id = curr_node->id(); - ready.pop_front(); - - Node* parent = nullptr; - if (IsEnter(curr_node)) { - // Enter a child frame. - TF_RETURN_IF_ERROR( - GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); - parent = curr_node; - } else if (IsExit(curr_node)) { - // Exit to the parent frame. - parent = parent_nodes[curr_id]; - frame_name = cf_info->frame_names[parent->id()]; - parent = parent_nodes[parent->id()]; - } else { - parent = parent_nodes[curr_id]; - frame_name = cf_info->frame_names[curr_id]; - } - - for (const Edge* out_edge : curr_node->out_edges()) { - Node* out = out_edge->dst(); - if (IsSink(out)) continue; - const int out_id = out->id(); - - // Add to ready queue if not visited. - bool is_visited = visited[out_id]; - if (!is_visited) { - ready.push_back(out); - visited[out_id] = true; - - // Process the node 'out'. - cf_info->frame_names[out_id] = frame_name; - parent_nodes[out_id] = parent; - cf_info->unique_frame_names.insert(frame_name); - } - } - } - - return Status::OK(); -} - -void ExecutorImpl::InitializePending(const Graph* graph, - const ControlFlowInfo& cf_info) { - for (auto& it : cf_info.unique_frame_names) { - FrameInfo* finfo = EnsureFrameInfo(it); - PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout); - DCHECK_EQ(finfo->pending_counts, nullptr); - finfo->pending_counts = counts; - } - for (const Node* n : graph->nodes()) { - if (IsSink(n)) continue; - const int id = n->id(); - const string& name = cf_info.frame_names[id]; - size_t max_pending, max_dead; - GetMaxPendingCounts(n, &max_pending, &max_dead); - PendingCounts* counts = EnsureFrameInfo(name)->pending_counts; - counts->set_initial_count(pending_ids_[id], max_pending); - } -} - void ExecutorState::RunAsync(Executor::DoneCallback done) { TaggedNodeSeq ready; // Ask the device to fill in the device context map. - Device* device = impl_->params_.device; + Device* device = immutable_state_.params().device; const Status get_context_status = device->TryGetDeviceContext(&device_context_); if (!get_context_status.ok()) { @@ -1815,8 +896,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { } // Initialize the ready queue. - ready.reserve(impl_->root_nodes_.size()); - for (const NodeItem* item : impl_->root_nodes_) { + ready.reserve(immutable_state_.root_nodes().size()); + for (const NodeItem* item : immutable_state_.root_nodes()) { DCHECK_EQ(item->num_inputs, 0); ready.push_back(TaggedNode{item, root_frame_, 0, false}); } @@ -1906,8 +987,8 @@ Status ExecutorState::ProcessSync(const NodeItem& item, nodestats::SetOpStart(stats); OpKernel* op_kernel = item.kernel; - Device* device = impl_->params_.device; - const bool is_expensive = impl_->kernel_stats_.IsExpensive(item); + Device* device = immutable_state_.params().device; + const bool is_expensive = kernel_stats_->IsExpensive(item); if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) { tracing::ScopedRegion region(tracing::EventCategory::kCompute, @@ -1926,7 +1007,7 @@ Status ExecutorState::ProcessSync(const NodeItem& item, if (is_expensive) { KernelTimer timer; device->Compute(op_kernel, &ctx); - impl_->kernel_stats_.UpdateCostEstimate(item, timer.ElapsedCycles()); + kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles()); } else { device->Compute(op_kernel, &ctx); } @@ -1948,7 +1029,7 @@ void ExecutorState::ProcessAsync(const NodeItem& item, new AsyncState(params, tagged_node, &item, first_input, stats); auto done = [this, state]() { - Device* device = impl_->params_.device; + Device* device = immutable_state_.params().device; NodeExecStatsInterface* stats = state->stats; // Shorthand Entry* first_input = state->first_input; // Shorthand @@ -1987,9 +1068,9 @@ void ExecutorState::ProcessAsync(const NodeItem& item, return async_kernel->TraceString( &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); }, - profiler::GetTFTraceMeLevel(impl_->kernel_stats_.IsExpensive(item))); - impl_->params_.device->ComputeAsync(async_kernel, &state->ctx, - std::move(done)); + profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item))); + immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx, + std::move(done)); } } @@ -2028,7 +1109,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { OpKernelContext::Params params; params.step_id = step_id_; // Override device's threadpool if user provides an intra_op_threadpool - Device* device = impl_->params_.device; + Device* device = immutable_state_.params().device; if (user_device_) { params.device = user_device_.get(); } else { @@ -2043,7 +1124,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.tensor_store = tensor_store_; params.cancellation_manager = cancellation_manager_; params.call_frame = call_frame_; - params.function_library = impl_->params_.function_library; + params.function_library = immutable_state_.params().function_library; params.resource_manager = device->resource_manager(); params.step_container = step_container_; params.slice_reader_cache = slice_reader_cache_; @@ -2093,7 +1174,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(input_frame->mu); input_frame->GetIteration(input_iter) - ->mark_started(impl_->pending_ids_[id]); + ->mark_started(immutable_state_.pending_ids()[id]); } params.track_allocations = false; @@ -2431,7 +1512,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, mutex_lock l(input_frame->mu); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &impl_->gview_, input_iter, ready); + &immutable_state_.graph_view(), input_iter, ready); } else if (item->is_enter) { FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); output_iter = 0; @@ -2445,8 +1526,8 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, } output_frame->num_pending_inputs--; } - is_frame_done = - input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); + is_frame_done = input_frame->DecrementOutstandingOps( + &immutable_state_.graph_view(), input_iter, ready); } else if (item->is_exit) { if (is_dead) { mutex_lock l(input_frame->mu); @@ -2455,7 +1536,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, input_frame->dead_exits.push_back(item); } is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &impl_->gview_, input_iter, ready); + &immutable_state_.graph_view(), input_iter, ready); } else { output_frame = input_frame->parent_frame; output_iter = input_frame->parent_iter; @@ -2463,8 +1544,8 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, mutex_lock l(output_frame->mu); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } - is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, - input_iter, ready); + is_frame_done = input_frame->DecrementOutstandingOps( + &immutable_state_.graph_view(), input_iter, ready); } } else { DCHECK(item->is_next_iteration); @@ -2482,7 +1563,8 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, } else { // If this is a new iteration, start it. if (input_iter == input_frame->iteration_count) { - input_frame->IncrementIteration(&impl_->gview_, ready); + input_frame->IncrementIteration(&immutable_state_.graph_view(), + ready); } output_iter = input_iter + 1; } @@ -2493,7 +1575,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &impl_->gview_, input_iter, ready); + &immutable_state_.graph_view(), input_iter, ready); } // At this point, this node is completely done. We also know if the @@ -2516,7 +1598,7 @@ bool ExecutorState::NodeDone(const Status& s, TaggedNodeSeq* ready, nodestats::SetAllEnd(stats); if (stats) { if (stats_collector_) { - stats->Done(impl_->params_.device->name()); + stats->Done(immutable_state_.params().device->name()); } else { delete stats; } @@ -2543,7 +1625,7 @@ bool ExecutorState::NodeDone(const Status& s, TaggedNodeSeq* ready, TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); if (cancellation_manager_) { // only log when the abort happens during the actual run time. - auto device_name = impl_->params_.device->name(); + auto device_name = immutable_state_.params().device->name(); // Use VLOG instead of LOG(warning) because error status is expected when // the executor is run under the grappler optimization phase or when // iterating through a tf.data input pipeline. @@ -2611,7 +1693,7 @@ void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, } else { for (auto& tagged_node : *ready) { const NodeItem& item = *tagged_node.node_item; - if (tagged_node.is_dead || !impl_->kernel_stats_.IsExpensive(item)) { + if (tagged_node.is_dead || !kernel_stats_->IsExpensive(item)) { // Inline this inexpensive node. inline_ready->push_back(tagged_node); } else { @@ -2645,7 +1727,8 @@ inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter, // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(frame->mu); - frame->GetIteration(iter)->mark_completed(impl_->pending_ids_[node_id]); + frame->GetIteration(iter)->mark_completed( + immutable_state_.pending_ids()[node_id]); } } @@ -2665,7 +1748,7 @@ const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) { void ExecutorState::DumpPendingNodeState( const int node_id, const Entry* input_vector, const bool show_nodes_with_no_ready_inputs) { - const NodeItem& node_item = *impl_->gview_.node(node_id); + const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); const int input_base = node_item.input_start; if (!show_nodes_with_no_ready_inputs) { bool has_ready_input = false; @@ -2698,7 +1781,7 @@ void ExecutorState::DumpPendingNodeState( void ExecutorState::DumpActiveNodeState(const int node_id, const Entry* input_vector) { - const NodeItem& node_item = *impl_->gview_.node(node_id); + const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); LOG(WARNING) << " Active Node: " << node_item.DebugString(); const int input_base = node_item.input_start; for (int i = 0; i < node_item.num_inputs; ++i) { @@ -2720,7 +1803,8 @@ void ExecutorState::DumpIterationState(const FrameState* frame, const std::vector* nodes = frame->nodes; // Dump any waiting nodes that are holding on to tensors. for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = impl_->pending_ids_[node->node_id]; + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { DumpPendingNodeState(node->node_id, iteration->input_tensors, false); @@ -2728,7 +1812,8 @@ void ExecutorState::DumpIterationState(const FrameState* frame, } // Then the active nodes. for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = impl_->pending_ids_[node->node_id]; + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::STARTED) { DumpActiveNodeState(node->node_id, iteration->input_tensors); } @@ -2791,7 +1876,7 @@ void ExecutorState::Finish() { mu_.unlock(); int64 step_id = step_id_; CHECK(done_cb != nullptr); - Device* device = impl_->params_.device; + Device* device = immutable_state_.params().device; // There are several potential race conditions below. To name a few: // 1. Even if the device's status is OK at the precise moment when @@ -2906,7 +1991,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, DCHECK(found_parallel_iters) << "Could not find \"parallel_iterations\" attr in node " << node_item.kernel->name(); - FrameState* temp = new FrameState(impl_, parallel_iters); + FrameState* temp = new FrameState(immutable_state_, parallel_iters); temp->frame_name = child_name; temp->frame_id = Hash64(child_name); temp->parent_frame = frame; @@ -2963,8 +2048,9 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { }; for (const EdgeInfo& e : item->output_edges()) { - const NodeItem& dst_item = *impl_->gview_.node(e.dst_id); - const auto dst_pending_id = impl_->pending_ids_[e.dst_id]; + const NodeItem& dst_item = + *immutable_state_.graph_view().node(e.dst_id); + const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead = true; bool dst_ready; @@ -2982,8 +2068,9 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { } for (const ControlEdgeInfo& e : item->output_control_edges()) { - const NodeItem& dst_item = *impl_->gview_.node(e.dst_id); - const auto dst_pending_id = impl_->pending_ids_[e.dst_id]; + const NodeItem& dst_item = + *immutable_state_.graph_view().node(e.dst_id); + const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead; bool dst_ready; @@ -3019,7 +2106,8 @@ void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, { mutex_lock frame_lock(frame->mu); frame->GetIteration(iter)->outstanding_frame_count--; - is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready); + is_frame_done = + frame->CleanupIterations(&immutable_state_.graph_view(), iter, ready); } if (is_frame_done) { FrameState* parent_frame = frame->parent_frame; @@ -3041,24 +2129,32 @@ void ExecutorState::FrameState::ActivateNodesFastPath(const NodeItem* item, // If we know that none of the item's edge destinations require special // handling (i.e. none of the nodes is a merge or control trigger node), we // can take a fast path that avoids accessing the destination NodeItem. - const GraphView& gview = executor->gview_; + const GraphView& gview = immutable_state.graph_view(); IterationState* iter_state = GetIteration(iter); - auto maybe_add_to_ready = [&](int dst_id, - PendingCounts::AdjustResult adjust_result) { - // Add dst to the ready queue if it's ready - if (!adjust_result.any_pending) { - const NodeItem* dst_item = gview.node(dst_id); - ready->emplace_back(dst_item, this, iter, adjust_result.any_dead); - iter_state->outstanding_ops++; - } - }; +// Add dst to the ready queue if it's ready +// +// NOTE(mrry): Use a macro here instead of a lambda, because this method is +// performance-critical and we need to ensure that the code is inlined. +#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ + do { \ + if (!adjust_result.any_pending) { \ + const NodeItem* dst_item = gview.node(dst_id); \ + TaggedNode& t = ready->emplace_back(); \ + t.node_item = dst_item; \ + t.input_frame = this; \ + t.input_iter = iter; \ + t.is_dead = adjust_result.any_dead; \ + iter_state->outstanding_ops++; \ + } \ + } while (0); Entry* input_tensors = iter_state->input_tensors; for (const EdgeInfo& e : item->output_edges()) { const int dst_id = e.dst_id; - const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; const int src_slot = e.output_slot; const bool increment_dead = @@ -3071,16 +2167,18 @@ void ExecutorState::FrameState::ActivateNodesFastPath(const NodeItem* item, } else { input_tensors[dst_loc] = (*outputs)[src_slot]; } - maybe_add_to_ready(dst_id, adjust_result); + MAYBE_ADD_TO_READY(dst_id, adjust_result); } for (const ControlEdgeInfo& e : item->output_control_edges()) { const int dst_id = e.dst_id; - const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; const PendingCounts::AdjustResult adjust_result = iter_state->adjust_for_activation(dst_pending_id, is_dead); - maybe_add_to_ready(dst_id, adjust_result); + MAYBE_ADD_TO_READY(dst_id, adjust_result); } +#undef MAYBE_ADD_TO_READY } void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, @@ -3091,7 +2189,7 @@ void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, // If any of the edge destinations is a merge or a control trigger node, // we need to read each destination NodeItem to determine what action // to take. - const GraphView& gview = executor->gview_; + const GraphView& gview = immutable_state.graph_view(); IterationState* iter_state = GetIteration(iter); auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, @@ -3109,7 +2207,8 @@ void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const EdgeInfo& e : item->output_edges()) { const int dst_id = e.dst_id; const NodeItem* dst_item = gview.node(dst_id); - const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; const int src_slot = e.output_slot; bool dst_dead = false; @@ -3170,7 +2269,8 @@ void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const ControlEdgeInfo& e : item->output_control_edges()) { const int dst_id = e.dst_id; const NodeItem* dst_item = gview.node(dst_id); - const PendingCounts::Handle dst_pending_id = executor->pending_ids_[dst_id]; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; bool dst_dead; bool dst_ready; @@ -3302,7 +2402,8 @@ bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview, } void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { - (new ExecutorState(args, this))->RunAsync(std::move(done)); + (new ExecutorState(args, immutable_state_, &kernel_stats_)) + ->RunAsync(std::move(done)); } } // namespace diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 9f4774d21ca..c0d69cf6d93 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -164,10 +164,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr, /*parent=*/nullptr, /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - })); + Rendezvous::Factory{ + [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + return Status::OK(); + }})); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index 3b14bd4a0f2..3b1f90e7198 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -68,10 +68,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool, /*parent=*/nullptr, /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - })); + Rendezvous::Factory{ + [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + return Status::OK(); + }})); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index dcc40c3d3de..da6a2eadea2 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -340,7 +340,7 @@ Status BaseGPUDevice::InitScratchBuffers() { if (!scratch_) { DCHECK(stream_); size_t scratch_buffer_size = Eigen::kGpuScratchSize + sizeof(unsigned int); - MEMDEBUG_CACHE_OP("ScratchBuffer"); + auto op_annotation = ScopedMemoryDebugAnnotation("ScratchBuffer"); void* scratch_buffer = gpu_allocator_->AllocateRaw( Allocator::kAllocatorAlignment, scratch_buffer_size); if (scratch_buffer == nullptr) { @@ -498,8 +498,8 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { } } ScopedActivateExecutorContext scoped_activation{stream->parent()}; - MEMDEBUG_CACHE_OP(op_kernel->name().c_str()); - MEMDEBUG_CACHE_STEPID(context->step_id()); + auto op_annotation = ScopedMemoryDebugAnnotation( + op_kernel->name_view().data(), context->step_id()); op_kernel->Compute(context); if (context->status().ok()) { if (sync_every_op_) { @@ -612,7 +612,7 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU( Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { - MEMDEBUG_CACHE_OP( + auto op_annotation = ScopedMemoryDebugAnnotation( (pending_op_name != nullptr ? pending_op_name : "MakeTensorFromProto")); AllocatorAttributes attr; attr.set_on_host(true); diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc new file mode 100644 index 00000000000..7db0781551d --- /dev/null +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -0,0 +1,442 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/graph_view.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +string NodeItem::DebugString() const { + string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id); + if (is_source) { + strings::StrAppend(&ret, " source}"); + } else { + strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}"); + } + return ret; +} + +GraphView::~GraphView() { + static_assert(std::is_trivially_destructible::value, + "Update code if AllocatorAttributes gains a destructor"); + static_assert(std::is_trivially_destructible::value, + "Update code if EdgeInfo gains a destructor"); + for (int i = 0; i < num_nodes_; i++) { + NodeItem* n = node(i); + if (n != nullptr) { + n->NodeItem::~NodeItem(); + // Memory for "n" itself is held in space_ & gets cleaned up below + } + } + delete[] node_offsets_; + delete[] space_; +} + +namespace { +typedef std::tuple OutputAndControlEdges; + +OutputAndControlEdges CountOutputEdges(const Node* n) { + DCHECK_LE(n->out_edges().size(), kint32max); + int32 num_output_edges = 0; + int32 num_output_control_edges = 0; + for (auto e : n->out_edges()) { + if (IsSink(e->dst())) continue; + if (e->IsControlEdge()) { + ++num_output_control_edges; + } else { + ++num_output_edges; + } + } + return OutputAndControlEdges(num_output_edges, num_output_control_edges); +} +} // namespace + +size_t GraphView::NodeItemBytes(const Node* n) { + int32 num_output_edges; + int32 num_output_control_edges; + std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n); + const int num_inputs = n->num_inputs(); + const int num_outputs = n->num_outputs(); + + // Compute number of bytes needed for NodeItem and variable length data. + // We do not subtract sizeof(var) since num_inputs/num_outputs might + // both be zero. + const size_t raw_bytes = + sizeof(NodeItem) // Fixed + + num_output_edges * sizeof(EdgeInfo) // output_edges[...] + + num_output_control_edges * // + sizeof(ControlEdgeInfo) // output_control_edges[...] + + num_outputs * sizeof(AllocatorAttributes) // output_attr[...] + + num_outputs * sizeof(int) // forward_from[num_outputs] + + num_inputs * sizeof(uint8) // input_type[num_inputs] + + num_outputs * sizeof(uint8); // output_type[num_outputs] + static constexpr size_t kItemAlignment = sizeof(NodeItem*); + static_assert(kItemAlignment % alignof(NodeItem) == 0, + "NodeItem must be aligned with kItemAlignment"); + static_assert(kItemAlignment % alignof(EdgeInfo) == 0, + "EdgeInfo must be aligned with kItemAlignment"); + static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0, + "ControlEdgeInfo must be aligned with kItemAlignment"); + static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0, + "AllocatorAttributes must be aligned with kItemAlignment"); + static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0, + "NodeItem must be aligned with EdgeInfo"); + static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0, + "NodeItem must be aligned with AllocatorAttributes"); + static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0, + "EdgeInfo must be aligned with AllocatorAttributes"); + const size_t bytes = + ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment; + return bytes; +} + +char* GraphView::InitializeNode(char* ptr, const Node* n) { + const int id = n->id(); + CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor + + const size_t bytes = NodeItemBytes(n); + constexpr size_t kItemAlignment = sizeof(NodeItem*); + CHECK_EQ(reinterpret_cast(ptr) % kItemAlignment, 0); + NodeItem* item = reinterpret_cast(ptr); + + // We store a 32-bit offset relative to the beginning of space_, so that we + // only need an array of 32-bit values to map from node id to the NodeItem*, + // (versus 64 bits on most machines if we just stored an array of NodeItem* + // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing + // values as "int" vs "size_t" in CHECK_LE. + CHECK_LE(static_cast(ptr - space_), kuint32max); + const uint32 offset = static_cast(ptr - space_); + node_offsets_[id] = offset; + ptr += bytes; + + int32 num_output_edges; + int32 num_output_control_edges; + std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n); + const int num_inputs = n->num_inputs(); + const int num_outputs = n->num_outputs(); + + new (item) NodeItem(); + item->num_inputs = num_inputs; + item->num_outputs = num_outputs; + item->num_output_edges = num_output_edges; + item->num_output_control_edges = num_output_control_edges; + + // Fill output edges. + // Keep track of the last EdgeInfo in the EdgeInfo array that references + // a given output slot. For all but the last, we need to do a copy of the + // Tensor when propagating results downstream in the graph, but for the + // last one, we can just do a move of the Tensor object to propagate it. + gtl::InlinedVector last_indices(num_outputs, nullptr); + EdgeInfo* dst_edge = item->output_edge_base(); + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) continue; + dst_edge->dst_id = e->dst()->id(); + CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits + dst_edge->output_slot = e->src_output(); + dst_edge->is_last = false; + const int output_slot = dst_edge->output_slot; + if (output_slot >= 0) { + last_indices[output_slot] = dst_edge; + } + // NOTE: The `input_slot` will be rewritten to the frame-wide offset later + // in `ExecutorImpl::Initialize()`. + dst_edge->input_slot = e->dst_input(); + dst_edge++; + } + for (EdgeInfo* edge_info : last_indices) { + if (edge_info != nullptr) { + edge_info->is_last = true; + } + } + ControlEdgeInfo* dst_control_edge = item->output_control_edge_base(); + for (auto e : n->out_edges()) { + if (!e->IsControlEdge() || IsSink(e->dst())) continue; + dst_control_edge->dst_id = e->dst()->id(); + dst_control_edge++; + } + + AllocatorAttributes* output_attrs = item->output_attr_base(); + for (int i = 0; i < num_outputs; i++) { + new (&output_attrs[i]) AllocatorAttributes(); + } + + DCHECK_LT(DataType_MAX, 255); // Must fit in uint8 + uint8* input_types = item->input_type_base(); + for (int i = 0; i < num_inputs; i++) { + input_types[i] = static_cast(n->input_type(i)); + DCHECK_EQ(item->input_type(i), n->input_type(i)); + } + + // Check ScopedAllocatorAttrs and forward_from. Also assign output_types. + { + std::vector forward_input; + Status fwd_status = + GetNodeAttr(n->attrs(), "_forward_input", &forward_input); + std::vector scoped_allocator_attrs; + Status sa_status = + GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs); + + int* forward_from = item->forward_from_base(); + uint8* output_types = item->output_type_base(); + for (int i = 0; i < num_outputs; ++i) { + output_types[i] = static_cast(n->output_type(i)); + DCHECK_EQ(item->output_type(i), n->output_type(i)); + + forward_from[i] = OpKernelContext::Params::kNoReservation; + if (sa_status.ok()) { + for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) { + if (scoped_allocator_attrs[j] == i) { + // This output slot must be explicitly allocated from a + // ScopedAllocator. + forward_from[i] = OpKernelContext::Params::kNeverForward; + DCHECK_EQ(output_attrs[i].scope_id, 0); + output_attrs[i].scope_id = scoped_allocator_attrs[j + 1]; + } + } + } + if (fwd_status.ok() && + forward_from[i] == OpKernelContext::Params::kNoReservation) { + DCHECK_EQ(forward_input.size() % 2, 0); + for (int j = 0; j < forward_input.size(); j += 2) { + if (forward_input[j + 1] == i) { + DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation); + forward_from[i] = forward_input[j]; + break; + } + } + } + } + } + + return ptr; +} + +Status GraphView::Initialize(const Graph* g) { + CHECK(node_offsets_ == nullptr); + const int num_nodes = g->num_node_ids(); + num_nodes_ = num_nodes; + size_t total_bytes = 0; + for (const Node* n : g->nodes()) { + if (n->out_edges().size() > kint32max) { + return errors::InvalidArgument( + "The executor cannot handle nodes with more than ", kint32max, + " output edges. Node ", n->name(), " had ", n->out_edges().size(), + " output edges."); + } + total_bytes += NodeItemBytes(n); + } + + node_offsets_ = new uint32[num_nodes]; + for (int i = 0; i < num_nodes; i++) { + node_offsets_[i] = kuint32max; + } + + space_ = new char[total_bytes]; // NodeItem objects are allocated here + char* ptr = space_; + for (const Node* n : g->nodes()) { + ptr = InitializeNode(ptr, n); + } + CHECK_EQ(ptr, space_ + total_bytes); + return Status::OK(); +} + +namespace { +// If a Node has been marked to use a ScopedAllocator x for output i, then +// sc_attr will contain the subsequence (i, x) at an even offset. This function +// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we +// only allow one ScopedAllocator use per Node. +bool ExtractScopedAllocatorAttr(const std::vector& sc_attr, + int output_index, + AllocatorAttributes* alloc_attr) { + DCHECK_LE(2, sc_attr.size()); + for (int i = 0; i < sc_attr.size(); i += 2) { + if (sc_attr[i] == output_index) { + CHECK_EQ(alloc_attr->scope_id, 0); + alloc_attr->scope_id = sc_attr[i + 1]; + return true; + } + } + return false; +} +} // namespace + +void GraphView::SetScopedAllocatorAttrs( + const std::vector& sa_nodes) { + for (const Node* sa : sa_nodes) { + NodeItem* sa_item = node(sa->id()); + AllocatorAttributes* sa_attrs = sa_item->output_attr_base(); + // Control edges out of the ScopedAllocator should be use instances, but may + // include a few other nodes. + for (const auto& e : sa->out_edges()) { + if (IsSink(e->dst()) || !e->IsControlEdge()) { + continue; + } + Node* use_node = e->dst(); + NodeItem* item = node(use_node->id()); + AllocatorAttributes* use_attrs = item->output_attr_base(); + std::vector scoped_allocator_attrs; + Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator", + &scoped_allocator_attrs); + if (!s.ok()) { + VLOG(2) << "Failed to find expected ScopedAllocator attr on " + << use_node->name(); + continue; + } + // There can be more than one output using ScopedAllocation, but this + // analysis assumes they use the same ScopedAllocator. + for (const auto& e : use_node->out_edges()) { + if (IsSink(e->dst()) || !e->IsControlEdge()) { + AllocatorAttributes attr; + if (ExtractScopedAllocatorAttr(scoped_allocator_attrs, + e->src_output(), &attr)) { + // Set the scope_id on this use instance node. + (use_attrs + e->src_output())->Merge(attr); + // Propagate the other attributes of this node back to the SA node. + attr = *(use_attrs + e->src_output()); + attr.scope_id = 0; + sa_attrs->Merge(attr); + } + } + } + } + } +} + +namespace { +Status InferAllocAttr(const Node* n, const Node* dst, + const DeviceNameUtils::ParsedName& local_dev_name, + AllocatorAttributes* attr) { + Status s; + // Note that it's possible for *n to be a Recv and *dst to be a Send, + // so these two cases are not mutually exclusive. + if (IsRecv(n)) { + string src_name; + s = GetNodeAttr(n->attrs(), "send_device", &src_name); + if (!s.ok()) return s; + DeviceNameUtils::ParsedName parsed_src_name; + if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) { + s = errors::Internal("Bad send_device attr '", src_name, "' in node ", + n->name()); + return s; + } + if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) { + // Value is going to be the sink of an RPC. + attr->set_nic_compatible(true); + VLOG(2) << "node " << n->name() << " is the sink of an RPC in"; + } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) && + parsed_src_name.type != "CPU") { + // Value is going to be the sink of a local DMA from GPU to CPU (or + // other types of accelerators). + attr->set_gpu_compatible(true); + VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy"; + } else { + VLOG(2) << "default alloc case local type " << local_dev_name.type + << " remote type " << parsed_src_name.type; + } + } + if (IsSend(dst)) { + string dst_name; + s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name); + if (!s.ok()) return s; + DeviceNameUtils::ParsedName parsed_dst_name; + if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) { + s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ", + n->name()); + return s; + } + if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) { + // Value is going to be the source of an RPC. + attr->set_nic_compatible(true); + VLOG(2) << "node " << n->name() << " is the source of an RPC out"; + } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) && + parsed_dst_name.type != "CPU") { + // Value is going to be the source of a local DMA from CPU to GPU (or + // other types of accelerators). + // Note that this does not cover the case where the allocation of the + // output tensor is not generated by the src: n. + attr->set_gpu_compatible(true); + VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy"; + } else { + VLOG(2) << "default alloc case local type " << local_dev_name.type + << " remote type " << parsed_dst_name.type; + } + } + if (n->IsCollective()) { + // We'll make the sweeping assumption that any collective op is going + // to be involved in network i/o. + attr->set_nic_compatible(true); + } + return s; +} +} // namespace + +Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { + Status s; + DeviceNameUtils::ParsedName local_dev_name = device->parsed_name(); + + std::vector scoped_allocator_instances; + for (const Node* n : g->nodes()) { + NodeItem* item = node(n->id()); + AllocatorAttributes* attrs = item->output_attr_base(); + if (IsScopedAllocator(n)) { + scoped_allocator_instances.push_back(n); + } + + // Examine the out edges of each node looking for special use + // cases that may affect memory allocation attributes. + for (const auto& e : n->out_edges()) { + if (!e->IsControlEdge()) { + AllocatorAttributes attr; + s = InferAllocAttr(n, e->dst(), local_dev_name, &attr); + if (!s.ok()) return s; + if (attr.value != 0 || attr.scope_id != 0) { + attrs[e->src_output()].Merge(attr); + } + } + } + + for (int out = 0; out < n->num_outputs(); out++) { + const OpKernel* op_kernel = item->kernel; + DCHECK_LT(out, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; + if (on_host) { + AllocatorAttributes h; + h.set_on_host(on_host); + attrs[out].Merge(h); + } + } + } + SetScopedAllocatorAttrs(scoped_allocator_instances); + return s; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h new file mode 100644 index 00000000000..b0bc0f4b6de --- /dev/null +++ b/tensorflow/core/common_runtime/graph_view.h @@ -0,0 +1,240 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Device; +class Graph; +class Node; +class OpKernel; +class Tensor; + +// Represents a single data edge in a `NodeItem`. +struct EdgeInfo { + // The node ID of the destination in the containing `GraphView`. + int dst_id; + // The index of the output that produces values on this edge. + int output_slot : 31; + // true if this is the last info for output_slot in the EdgeInfo list. + bool is_last : 1; + // The index of the input that consumes values on this edge. + int input_slot; +}; + +// Represents a single control edge in a `NodeItem`. +struct ControlEdgeInfo { + // The node ID of the destination in the containing `GraphView`. + int dst_id; +}; + +// Compact structure representing a graph node and its associated kernel. +// +// Each NodeItem is an element of exactly one GraphView. +struct NodeItem { + // The index of this node's item in its GraphView. + int node_id = -1; + + // Cached attributes of this node for fast lookup. + bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr + bool is_merge : 1; // True iff IsMerge(node) + bool is_enter : 1; // True iff IsEnter(node) + bool is_constant_enter : 1; // True iff IsEnter(node) and + // node->GetAttr("is_constant") == true. + bool is_exit : 1; // True iff IsExit(node) + bool is_control_trigger : 1; // True iff IsControlTrigger(node) + bool is_source : 1; // True iff IsSource(node) + // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) + bool is_enter_exit_or_next_iter : 1; + bool is_transfer_node : 1; // True iff IsTransferNode(node) + bool is_initialization_op : 1; // True iff IsInitializationOp(node) + bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) + bool is_next_iteration : 1; // True iff IsNextIteration(node) + bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") + bool + is_any_consumer_merge_or_control_trigger : 1; // True iff the destination + // of any output edge is a + // merge or control trigger + // node. + + // The kernel for this node. + OpKernel* kernel = nullptr; + + // If the kernel is a Const op, this containts points to the constant tensor. + const Tensor* const_tensor = nullptr; + + // Cached values of node->num_inputs() and node->num_outputs(), to + // avoid levels of indirection. + int num_inputs; + int num_outputs; + + // ExecutorImpl::tensors_[input_start] is the 1st positional input + // for this node. + int input_start = 0; + + // Number of output edges, excluding control edges. + int32 num_output_edges; + + // Number of output control edges. + int32 num_output_control_edges; + + // If non-null, contains an array of num_outputs bools, where the ith bool + // is true if and only if the ith output is consumed by another node. + std::unique_ptr outputs_required; + + gtl::MutableArraySlice mutable_output_edges() { + return gtl::MutableArraySlice(output_edge_base(), + num_output_edges); + } + + gtl::ArraySlice output_edges() const { + return gtl::ArraySlice(output_edge_base(), num_output_edges); + } + + gtl::ArraySlice output_control_edges() const { + return gtl::ArraySlice(output_control_edge_base(), + num_output_control_edges); + } + + DataType input_type(int i) const { + DCHECK_LT(i, num_inputs); + return static_cast(input_type_base()[i]); + } + DataType output_type(int i) const { + DCHECK_LT(i, num_outputs); + return static_cast(output_type_base()[i]); + } + + // Return array of per-output allocator attributes. + const AllocatorAttributes* output_attrs() const { return output_attr_base(); } + + // Return array of expected input index from which each output should + // be forwarded: + // kNeverForward (-2) for DO NOT FORWARD (must allocate). + // kNoReservation (-1) for no expected forwarding. + // 0... for forward from that input. + const int* forward_from() const { return forward_from_base(); } + + string DebugString() const; + + private: + friend class GraphView; + + NodeItem() {} + + // Variable length section starts immediately after *this + // (uint8 is enough for DataType). + // EdgeInfo out_edges[num_output_edges]; + // ControlEdgeInfo out_control_edges[num_output_control_edges]; + // AllocatorAttributes output_attr[num_outputs]; + // int forward_from[num_outputs]; + // uint8 input_type[num_inputs]; + // uint8 output_type[num_outputs]; + + // Return pointer to variable length section. + char* var() const { + return const_cast(reinterpret_cast(this) + + sizeof(NodeItem)); + } + + EdgeInfo* output_edge_base() const { + return reinterpret_cast(var()); + } + + ControlEdgeInfo* output_control_edge_base() const { + return reinterpret_cast(var() + sizeof(EdgeInfo) * + num_output_edges); + } + + AllocatorAttributes* output_attr_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges); + } + int* forward_from_base() const { + return reinterpret_cast(var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * + num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs); + } + uint8* input_type_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); + } + uint8* output_type_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + + sizeof(uint8) * num_inputs); + } + + TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); +}; + +// Immutable view of a Graph organized for efficient execution. +// +// TODO(b/152651962): Add independent unit tests for this class. +class GraphView { + public: + GraphView() : space_(nullptr) {} + ~GraphView(); + + Status Initialize(const Graph* g); + Status SetAllocAttrs(const Graph* g, const Device* device); + void SetScopedAllocatorAttrs(const std::vector& sa_nodes); + + NodeItem* node(int32 id) const { + DCHECK_GE(id, 0); + DCHECK_LT(id, num_nodes_); + uint32 offset = node_offsets_[id]; + return ((offset == kuint32max) + ? nullptr + : reinterpret_cast(space_ + node_offsets_[id])); + } + + int32 num_nodes() const { return num_nodes_; } + + private: + char* InitializeNode(char* ptr, const Node* n); + size_t NodeItemBytes(const Node* n); + + int32 num_nodes_ = 0; + uint32* node_offsets_ = nullptr; // array of size "num_nodes_" + // node_offsets_[id] holds the byte offset for node w/ "id" in space_ + + char* space_; // NodeItem objects are allocated here + + TF_DISALLOW_COPY_AND_ASSIGN(GraphView); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index 66f77bd403e..df4cd4bffbb 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -409,7 +409,8 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank, const Tensor* src_tensor, const StatusCallback& done) { - MEMDEBUG_CACHE_OP(col_ctx_->op_ctx->op_kernel().name().c_str()); + auto op_annotation = ScopedMemoryDebugAnnotation( + col_ctx_->op_ctx->op_kernel().name_view().data()); string send_buf_key = BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); int dst_idx = diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc new file mode 100644 index 00000000000..97c17aa287d --- /dev/null +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -0,0 +1,319 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/immutable_executor_state.h" + +#include "absl/memory/memory.h" +#include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_node_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { +bool IsInitializationOp(const Node* node) { + return node->op_def().allows_uninitialized_input(); +} +} // namespace + +ImmutableExecutorState::~ImmutableExecutorState() { + for (int32 i = 0; i < gview_.num_nodes(); i++) { + NodeItem* item = gview_.node(i); + if (item != nullptr) { + params_.delete_kernel(item->kernel); + } + } + for (auto fiter : frame_info_) { + delete fiter.second; + } +} + +namespace { +void GetMaxPendingCounts(const Node* n, size_t* max_pending, + size_t* max_dead_count) { + const size_t num_in_edges = n->in_edges().size(); + size_t initial_count; + if (IsMerge(n)) { + // merge waits all control inputs so we initialize the pending + // count to be the number of control edges. + int32 num_control_edges = 0; + for (const Edge* edge : n->in_edges()) { + if (edge->IsControlEdge()) { + num_control_edges++; + } + } + // Use bit 0 to indicate if we are waiting for a ready live data input. + initial_count = 1 + (num_control_edges << 1); + } else { + initial_count = num_in_edges; + } + + *max_pending = initial_count; + *max_dead_count = num_in_edges; +} +} // namespace + +ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo( + const string& fname) { + auto slot = &frame_info_[fname]; + if (*slot == nullptr) { + *slot = new FrameInfo; + } + return *slot; +} + +Status ImmutableExecutorState::Initialize(const Graph& graph) { + TF_RETURN_IF_ERROR(gview_.Initialize(&graph)); + + // Build the information about frames in this subgraph. + ControlFlowInfo cf_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info)); + + for (auto& it : cf_info.unique_frame_names) { + EnsureFrameInfo(it)->nodes = + absl::make_unique>(); + } + + pending_ids_.resize(gview_.num_nodes()); + + // Preprocess every node in the graph to create an instance of op + // kernel for each node. + for (const Node* n : graph.nodes()) { + if (IsSink(n)) continue; + const int id = n->id(); + const string& frame_name = cf_info.frame_names[id]; + FrameInfo* frame_info = EnsureFrameInfo(frame_name); + + NodeItem* item = gview_.node(id); + item->node_id = id; + + item->input_start = frame_info->total_inputs; + frame_info->total_inputs += n->num_inputs(); + + Status s = params_.create_kernel(n->properties(), &item->kernel); + if (!s.ok()) { + item->kernel = nullptr; + s = AttachDef(s, *n); + return s; + } + CHECK(item->kernel); + item->kernel_is_async = (item->kernel->AsAsync() != nullptr); + item->is_merge = IsMerge(n); + item->is_any_consumer_merge_or_control_trigger = false; + for (const Node* consumer : n->out_nodes()) { + if (IsMerge(consumer) || IsControlTrigger(consumer)) { + item->is_any_consumer_merge_or_control_trigger = true; + break; + } + } + const Tensor* const_tensor = item->kernel->const_tensor(); + if (const_tensor) { + // Hold onto a shallow copy of the constant tensor in `*this` so that the + // reference count does not drop to 1. This prevents the constant tensor + // from being forwarded, and its buffer reused. + const_tensors_.emplace_back(*const_tensor); + } + item->const_tensor = const_tensor; + item->is_noop = (item->kernel->type_string_view() == "NoOp"); + item->is_enter = IsEnter(n); + if (item->is_enter) { + bool is_constant_enter; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter)); + item->is_constant_enter = is_constant_enter; + } else { + item->is_constant_enter = false; + } + item->is_exit = IsExit(n); + item->is_control_trigger = IsControlTrigger(n); + item->is_source = IsSource(n); + item->is_enter_exit_or_next_iter = + (IsEnter(n) || IsExit(n) || IsNextIteration(n)); + item->is_transfer_node = IsTransferNode(n); + item->is_initialization_op = IsInitializationOp(n); + item->is_recv_or_switch = IsRecv(n) || IsSwitch(n); + item->is_next_iteration = IsNextIteration(n); + + // Compute the maximum values we'll store for this node in the + // pending counts data structure, and allocate a handle in + // that frame's pending counts data structure that has enough + // space to store these maximal count values. + size_t max_pending, max_dead; + GetMaxPendingCounts(n, &max_pending, &max_dead); + pending_ids_[id] = + frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead); + + // See if this node is a root node, and if so, add item to root_nodes_. + if (n->in_edges().empty()) { + root_nodes_.push_back(item); + } + + // Initialize static information about the frames in the graph. + frame_info->nodes->push_back(item); + if (item->is_enter) { + string enter_name; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); + EnsureFrameInfo(enter_name)->input_count++; + } + + // Record information about whether each output of the op is used. + std::unique_ptr outputs_required(new bool[n->num_outputs()]); + std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false); + int32 unused_outputs = n->num_outputs(); + for (const Edge* e : n->out_edges()) { + if (IsSink(e->dst())) continue; + if (e->src_output() >= 0) { + if (!outputs_required[e->src_output()]) { + --unused_outputs; + outputs_required[e->src_output()] = true; + } + } + } + if (unused_outputs > 0) { + for (int i = 0; i < n->num_outputs(); ++i) { + if (!outputs_required[i]) { + metrics::RecordUnusedOutput(n->type_string()); + } + } + item->outputs_required = std::move(outputs_required); + } + } + + // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input + // location. + for (const Node* n : graph.nodes()) { + if (IsSink(n)) continue; + const int id = n->id(); + NodeItem* item = gview_.node(id); + + for (EdgeInfo& e : item->mutable_output_edges()) { + const int dst_id = e.dst_id; + NodeItem* dst_item = gview_.node(dst_id); + e.input_slot += dst_item->input_start; + } + } + + // Initialize PendingCounts only after pending_ids_[node.id] is initialized + // for all nodes. + InitializePending(&graph, cf_info); + return gview_.SetAllocAttrs(&graph, params_.device); +} + +namespace { +// If a Node has been marked to use a ScopedAllocator x for output i, then +// sc_attr will contain the subsequence (i, x) at an even offset. This function +// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we +// only allow one ScopedAllocator use per Node. +bool ExtractScopedAllocatorAttr(const std::vector& sc_attr, + int output_index, + AllocatorAttributes* alloc_attr) { + DCHECK_LE(2, sc_attr.size()); + for (int i = 0; i < sc_attr.size(); i += 2) { + if (sc_attr[i] == output_index) { + CHECK_EQ(alloc_attr->scope_id, 0); + alloc_attr->scope_id = sc_attr[i + 1]; + return true; + } + } + return false; +} +} // namespace + +Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, + ControlFlowInfo* cf_info) { + const int num_nodes = g->num_node_ids(); + cf_info->frame_names.resize(num_nodes); + std::vector parent_nodes; + parent_nodes.resize(num_nodes); + std::vector visited; + visited.resize(num_nodes); + + string frame_name; + std::deque ready; + + // Initialize with the root nodes. + for (Node* n : g->nodes()) { + if (n->in_edges().empty()) { + visited[n->id()] = true; + cf_info->unique_frame_names.insert(frame_name); + ready.push_back(n); + } + } + + while (!ready.empty()) { + Node* curr_node = ready.front(); + int curr_id = curr_node->id(); + ready.pop_front(); + + Node* parent = nullptr; + if (IsEnter(curr_node)) { + // Enter a child frame. + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); + parent = curr_node; + } else if (IsExit(curr_node)) { + // Exit to the parent frame. + parent = parent_nodes[curr_id]; + frame_name = cf_info->frame_names[parent->id()]; + parent = parent_nodes[parent->id()]; + } else { + parent = parent_nodes[curr_id]; + frame_name = cf_info->frame_names[curr_id]; + } + + for (const Edge* out_edge : curr_node->out_edges()) { + Node* out = out_edge->dst(); + if (IsSink(out)) continue; + const int out_id = out->id(); + + // Add to ready queue if not visited. + bool is_visited = visited[out_id]; + if (!is_visited) { + ready.push_back(out); + visited[out_id] = true; + + // Process the node 'out'. + cf_info->frame_names[out_id] = frame_name; + parent_nodes[out_id] = parent; + cf_info->unique_frame_names.insert(frame_name); + } + } + } + + return Status::OK(); +} + +void ImmutableExecutorState::InitializePending(const Graph* graph, + const ControlFlowInfo& cf_info) { + for (auto& it : cf_info.unique_frame_names) { + FrameInfo* finfo = EnsureFrameInfo(it); + DCHECK_EQ(finfo->pending_counts, nullptr); + finfo->pending_counts = + absl::make_unique(finfo->pending_counts_layout); + } + for (const Node* n : graph->nodes()) { + if (IsSink(n)) continue; + const int id = n->id(); + const string& name = cf_info.frame_names[id]; + size_t max_pending, max_dead; + GetMaxPendingCounts(n, &max_pending, &max_dead); + auto& counts = EnsureFrameInfo(name)->pending_counts; + counts->set_initial_count(pending_ids_[id], max_pending); + } +} +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h new file mode 100644 index 00000000000..c9c23e55a21 --- /dev/null +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -0,0 +1,126 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/graph_view.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Graph; + +// Represents the state of an executor (graph and control flow information) +// that is immutable throughout execution. +// +// TODO(b/152651962): Add independent unit tests for this class. +class ImmutableExecutorState { + public: + struct FrameInfo { + FrameInfo() + : input_count(0), + total_inputs(0), + pending_counts(nullptr), + nodes(nullptr) {} + + // The total number of inputs to a frame. + int input_count; + + // The total number of input tensors of a frame. + // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. + int total_inputs; + + // Used to determine the next place to allocate space in the + // pending_counts data structure we'll eventually construct + PendingCounts::Layout pending_counts_layout; + + // Each frame has its own PendingCounts only for the nodes in the frame. + std::unique_ptr pending_counts; + + // The nodes in a frame. Used only for debugging. + std::unique_ptr> nodes; + }; + + explicit ImmutableExecutorState(const LocalExecutorParams& p) + : params_(p), gview_() {} + ~ImmutableExecutorState(); + + Status Initialize(const Graph& graph); + + // Process all Nodes in the current graph, attempting to infer the + // memory allocation attributes to be used wherever they may allocate + // a tensor buffer. + Status SetAllocAttrs(); + + const LocalExecutorParams& params() const { return params_; } + const GraphView& graph_view() const { return gview_; } + const std::vector& pending_ids() const { + return pending_ids_; + } + const std::vector& root_nodes() const { return root_nodes_; } + + const FrameInfo* get_frame_info(const string& frame_name) const { + auto it_frame_info = frame_info_.find(frame_name); + if (it_frame_info == frame_info_.end()) { + return nullptr; + } else { + return it_frame_info->second; + } + } + + private: + struct ControlFlowInfo { + gtl::FlatSet unique_frame_names; + std::vector frame_names; + }; + + static Status BuildControlFlowInfo(const Graph* graph, + ControlFlowInfo* cf_info); + void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); + + FrameInfo* EnsureFrameInfo(const string& fname); + + // Owned. + LocalExecutorParams params_; + GraphView gview_; + std::vector pending_ids_; + + // Root nodes (with no in edges) that should form the initial ready queue + std::vector root_nodes_; + + // Mapping from frame name to static information about the frame. + // TODO(yuanbyu): We could cache it along with the graph so to avoid + // the overhead of constructing it for each executor instance. + gtl::FlatMap frame_info_; + + // Shallow copies of the constant tensors used in the graph. + std::vector const_tensors_; + + TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc index a26a678af13..a2065ff2bf1 100644 --- a/tensorflow/core/common_runtime/metrics.cc +++ b/tensorflow/core/common_runtime/metrics.cc @@ -45,13 +45,13 @@ auto* graph_run_time_usecs_histogram = monitoring::Sampler<0>::New( auto* graph_run_input_tensor_bytes = monitoring::Sampler<0>::New( {"/tensorflow/core/graph_run_input_tensor_bytes", "The size of input tensors in bytes."}, - // Power of 2 with bucket count 14 (256G) - {monitoring::Buckets::Exponential(1, 4, 20)}); + // Power of 2 with bucket count 14 (256MB) + {monitoring::Buckets::Exponential(1, 4, 14)}); auto* graph_run_output_tensor_bytes = monitoring::Sampler<0>::New( {"/tensorflow/core/graph_run_output_tensor_bytes", "The size of output tensors in bytes."}, - // Power of 2 with bucket count 14 (256G) + // Power of 2 with bucket count 14 (256MB) {monitoring::Buckets::Exponential(1, 4, 14)}); auto* graph_unused_outputs = monitoring::Counter<1>::New( @@ -72,8 +72,8 @@ auto* tf_data_bytes_fetched_counter = monitoring::Counter<0>::New( auto* tf_data_getnext_duration_counter = monitoring::Sampler<0>::New( {"/tensorflow/data/getnext_duration", "Microseconds spent fetching an element from tf.data Dataset iterator."}, - // Power of 2 with bucket count 14 (256G) - {monitoring::Buckets::Exponential(1, 4, 20)}); + // Power of 2 with bucket count 10 (1024 ms) + {monitoring::Buckets::Exponential(1, 2, 10)}); auto* tf_data_elements_counter = monitoring::Counter<1>::New( "/tensorflow/data/elements", "tf.data elements", "name"); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 9f9924b6ff2..0d5f042612e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( session_metadata_, this); } - DeviceMgr const* all_devices = device_mgr_; - if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { - all_devices = parent_->remote_device_mgr(); - } - - for (auto d : all_devices->ListDevices()) { - device_set_.AddDevice(d); - } + InitializeDeviceSet(); } /* static */ @@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( "function executions"); } +void ProcessFunctionLibraryRuntime::InitializeDeviceSet() { + DeviceMgr const* all_devices = device_mgr_; + if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { + all_devices = parent_->remote_device_mgr(); + } + + device_set_.reset(new DeviceSet); + for (auto d : all_devices->ListDevices()) { + device_set_->AddDevice(d); + } +} + FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( const string& device_name) const { Device* device = nullptr; @@ -225,7 +230,8 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( } const auto& iter = flr_map_->find(device); if (iter == flr_map_->end()) { - LOG(ERROR) << "Could not find device: " << device_name; + VLOG(1) << "Could not find device: " << device_name + << "in the local process."; return nullptr; } return iter->second.get(); @@ -678,7 +684,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( TF_RETURN_IF_ERROR( SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes)); TF_RETURN_IF_ERROR(PinArgsAndRets( - options.input_devices, options.output_devices, device_set_, arg_nodes, + options.input_devices, options.output_devices, *device_set_, arg_nodes, ret_nodes, options.config_proto.allow_soft_placement() ? default_device : nullptr)); @@ -691,7 +697,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( bool control_rets_updated = false; TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run( - device_set_, options.config_proto, &graph, &data->lib_def_, + *device_set_, options.config_proto, &graph, &data->lib_def_, &control_ret_node_names, &control_rets_updated)); if (control_rets_updated) { @@ -714,7 +720,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( optimization_options.session_options = &session_options; optimization_options.graph = &graph; optimization_options.flib_def = &data->lib_def_; - optimization_options.device_set = &device_set_; + optimization_options.device_set = device_set_.get(); optimization_options.is_function_graph = true; DumpGraph("Before running PRE_PLACEMENT passes", graph.get()); @@ -725,7 +731,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( // exceptions/warnings in case where nested function call options are ignored. DumpGraph("Before calling Placer", graph.get()); Placer placer(graph.get(), function_name, optimization_options.flib_def, - &device_set_, default_device, + device_set_.get(), default_device, options.config_proto.allow_soft_placement(), options.config_proto.log_device_placement()); TF_RETURN_IF_ERROR(placer.Run()); @@ -741,7 +747,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( DumpGraph("Before running graph optimization fn", graph.get()); Status status = options.optimize_graph_fn( std::move(ret_node_names), std::move(control_ret_node_names), - &data->lib_def_, device_set_, cpu_device, &graph); + &data->lib_def_, *device_set_, cpu_device, &graph); if (!status.ok()) { LOG(WARNING) << "Ignoring multi-device function optimization failure: " << status.ToString(); @@ -765,7 +771,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( std::unordered_map> subgraphs; TF_RETURN_IF_ERROR( - PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs)); + PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs)); for (const auto& pair : subgraphs) { DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (", @@ -841,7 +847,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( const string& target = pair.first; const string& device_type = - device_set_.FindDeviceByName(target)->device_type(); + device_set_->FindDeviceByName(target)->device_type(); Graph* subgraph = pair.second.get(); status->Update(UpdateArgAndRetvalMetadata( @@ -1258,12 +1264,18 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle( FunctionLibraryRuntime::DoneCallback ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback( std::vector>* items, - FunctionLibraryRuntime::DoneCallback done, - const Rendezvous* rendezvous) const { + FunctionLibraryRuntime::DoneCallback done, const int64 step_id, + const Rendezvous* created_rendezvous) const { return - [this, items, done = std::move(done), rendezvous](const Status& status) { - if (rendezvous) { - rendezvous->Unref(); + [this, items, done = std::move(done), step_id, + created_rendezvous](const Status& status) { + if (created_rendezvous) { + DCHECK(rendezvous_factory_); + created_rendezvous->Unref(); + Status s = rendezvous_factory_.CleanUp(step_id); + if (!s.ok()) { + LOG(ERROR) << s; + } } auto* local_status = new Status(status); CleanUp(items, [local_status, done](const Status& cleanup_status) { @@ -1281,15 +1293,16 @@ void ProcessFunctionLibraryRuntime::Run( std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { FunctionLibraryRuntime::Options new_opts = opts; - Rendezvous* rendezvous = nullptr; + Rendezvous* created_rendezvous = nullptr; if (!opts.rendezvous) { if (rendezvous_factory_) { - Status s = rendezvous_factory_(opts.step_id, device_mgr_, &rendezvous); + Status s = + rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous); if (!s.ok()) { done(s); return; } - new_opts.rendezvous = rendezvous; + new_opts.rendezvous = created_rendezvous; } else { done( errors::FailedPrecondition("The caller does not provide a rendezvous " @@ -1301,7 +1314,8 @@ void ProcessFunctionLibraryRuntime::Run( } auto* cleanup_items = new std::vector>; - done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done), rendezvous); + done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done), + new_opts.step_id, created_rendezvous); bool multi_device; { tf_shared_lock l(mu_); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 545615a1bea..f8550fd8bea 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntime { DistributedFunctionLibraryRuntime* parent = nullptr, const CustomKernelCreator* custom_kernel_creator = nullptr, const SessionMetadata* session_metadata = nullptr, - Rendezvous::Factory rendezvous_factory = nullptr); + Rendezvous::Factory rendezvous_factory = Rendezvous::Factory()); virtual ~ProcessFunctionLibraryRuntime() { // Deleting the FunctionLibraryRuntime map will delete the function handles @@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime { const DeviceMgr* device_mgr() { return device_mgr_; } - const DeviceSet* device_set() { return &device_set_; } + const DeviceSet* device_set() { return device_set_.get(); } + + // Initialize the set of local and remote devices for op device selection. + void InitializeDeviceSet(); const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } @@ -294,7 +297,7 @@ class ProcessFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback( std::vector>* items, - FunctionLibraryRuntime::DoneCallback done, + FunctionLibraryRuntime::DoneCallback done, const int64 step_id, const Rendezvous* rendezvous) const; DistributedFunctionLibraryRuntime* const parent_; @@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime { Env* const env_; const absl::optional config_; const DeviceMgr* const device_mgr_; - DeviceSet device_set_; + std::unique_ptr device_set_; const FunctionLibraryDefinition* lib_def_; thread::ThreadPool* default_thread_pool_; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 0b2d231e500..d53861a4d25 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include +#include #include #include "tensorflow/core/common_runtime/device_factory.h" @@ -122,10 +123,24 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr, cluster_flr_.get(), /*custom_kernel_creator=*/nullptr, session_metadata, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - })); + Rendezvous::Factory{ + [this](const int64 step_id, const DeviceMgr* device_mgr, + Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + if (rendezvous_ref_counts_.find(step_id) != + rendezvous_ref_counts_.end()) { + rendezvous_ref_counts_[step_id]++; + } else { + rendezvous_ref_counts_[step_id] = 1; + } + return Status::OK(); + }, + [this](const int64 step_id) { + CHECK(rendezvous_ref_counts_.find(step_id) != + rendezvous_ref_counts_.end()); + rendezvous_ref_counts_[step_id]--; + return Status::OK(); + }})); } Status Instantiate( @@ -289,6 +304,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr lib_def_; std::unique_ptr cluster_flr_; std::unique_ptr proc_flr_; + + // To ensure that we are cleaning up the rendezvous properly. + std::unordered_map rendezvous_ref_counts_; }; TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) { @@ -362,6 +380,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); + EXPECT_EQ(1, rendezvous_ref_counts_.size()); + EXPECT_EQ(opts.step_id, rendezvous_ref_counts_.begin()->first); + EXPECT_EQ(0, rendezvous_ref_counts_.begin()->second); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 6ed7df2cc1e..3df124e934b 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -74,7 +74,7 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr, return; } - MEMDEBUG_CACHE_OP("SameWorkerRecvDone"); + auto op_annotation = ScopedMemoryDebugAnnotation("SameWorkerRecvDone"); AllocatorAttributes attr = recv_args.alloc_attrs; attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || recv_args.alloc_attrs.gpu_compatible()); @@ -112,7 +112,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, RendezvousInterface::DoneCallback done) { VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey(); - MEMDEBUG_CACHE_OP("RecvAsync"); + auto op_annotation = ScopedMemoryDebugAnnotation("RecvAsync"); // Recv the tensor from local_. local->RecvAsync( parsed, recv_args, diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 46665846acc..bbb5dfe3a6f 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -6,6 +6,7 @@ load( ) load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "tf_cc_test", ) @@ -247,6 +248,16 @@ cc_library( ], ) +# This needs to be cc_header_only_library - tf_pybind_cc_library_wrapper +# does not pull in the server_lib.h header. +cc_header_only_library( + name = "server_lib_headers_lib", + features = ["-parse_headers"], + deps = [ + ":server_lib", + ], +) + cc_library( name = "server_lib", srcs = ["server_lib.cc"], diff --git a/tensorflow/core/data/service/python/BUILD b/tensorflow/core/data/service/python/BUILD new file mode 100644 index 00000000000..19bcaa3b952 --- /dev/null +++ b/tensorflow/core/data/service/python/BUILD @@ -0,0 +1,42 @@ +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +tf_python_pybind_extension( + name = "_pywrap_server_lib", + srcs = ["server_lib_wrapper.cc"], + module_name = "_pywrap_server_lib", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core/data/service:server_lib_headers_lib", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "//third_party/python_runtime:headers", + "@com_github_grpc_grpc//:grpc++_public_hdrs", + "@pybind11", + ], +) + +py_library( + name = "server_lib", + srcs = ["server_lib.py"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":_pywrap_server_lib", + ], +) + +tf_py_test( + name = "server_lib_test", + srcs = ["server_lib_test.py"], + deps = [ + ":server_lib", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/core/data/service/python/server_lib.py b/tensorflow/core/data/service/python/server_lib.py new file mode 100644 index 00000000000..d3636123e0f --- /dev/null +++ b/tensorflow/core/data/service/python/server_lib.py @@ -0,0 +1,95 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Python interface for creating dataset servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=invalid-import-order,g-bad-import-order, unused-import +from tensorflow.python import pywrap_tensorflow +from tensorflow.core.data.service.python import _pywrap_server_lib + + +class MasterServer(object): + """An in-process tf.data service master, for use in testing.""" + + def __init__(self, protocol): + """Creates and starts a new tf.data master server. + + The server will choose an available port. Use `target()` to get the string + for connecting to the server. + + Args: + protocol: A string representing the type of protocol to use when creating + channels. For no security, use "grpc". For local credentials, use + "grpc+local", and make sure your binary links in + `data/service:local_credentials`. + """ + self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(0, protocol) + + @property + def target(self): + """Returns the target for connecting to this server. + + The returned string will be in the form protocol://address:port, e.g. + "grpc://localhost:1000". + """ + return _pywrap_server_lib.TF_DATA_ServerTarget(self._server) + + def __del__(self): + """Shuts down and deletes the server. + + This method will block until all outstanding rpcs have completed and the + server has been shut down. + """ + _pywrap_server_lib.TF_DATA_DeleteServer(self._server) + + +class WorkerServer(object): + """An in-process tf.data service worker, for use in testing.""" + + def __init__(self, protocol, master_address): + """Creates and starts a new tf.data worker server. + + The server will choose an available port. Use `target()` to get the string + for connecting to the server. + + Args: + protocol: A string representing the type of protocol to use when creating + channels. For no security, use "grpc". For local credentials, use + "grpc+local", and make sure your binary links in + `data/service:local_credentials`. + master_address: The address of the tf.data master server to register with. + """ + self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( + 0, protocol, master_address) + + @property + def target(self): + """Returns the target for connecting to this server. + + The returned string will be in the form protocol://address:port, e.g. + "grpc://localhost:1000". + """ + return _pywrap_server_lib.TF_DATA_ServerTarget(self._server) + + def __del__(self): + """Shuts down and deletes the server. + + This method will block until all outstanding rpcs have completed and the + server has been shut down. + """ + _pywrap_server_lib.TF_DATA_DeleteServer(self._server) diff --git a/tensorflow/core/data/service/python/server_lib_test.py b/tensorflow/core/data/service/python/server_lib_test.py new file mode 100644 index 00000000000..6e9d6b9c043 --- /dev/null +++ b/tensorflow/core/data/service/python/server_lib_test.py @@ -0,0 +1,42 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.data service server lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.data.service.python import server_lib + +from tensorflow.python.platform import test + +PROTOCOL = "grpc" + + +class ServerLibTest(test.TestCase): + + def testStartMaster(self): + master = server_lib.MasterServer(PROTOCOL) + self.assertRegex(master.target, PROTOCOL + "://.*:.*") + + def testStartWorker(self): + master = server_lib.MasterServer(PROTOCOL) + worker = server_lib.WorkerServer(PROTOCOL, + master.target[len(PROTOCOL + "://"):]) + self.assertRegex(worker.target, PROTOCOL + "://.*:.*") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/data/service/python/server_lib_wrapper.cc b/tensorflow/core/data/service/python/server_lib_wrapper.cc new file mode 100644 index 00000000000..e273eb5b6a9 --- /dev/null +++ b/tensorflow/core/data/service/python/server_lib_wrapper.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "Python.h" +#include "include/pybind11/chrono.h" +#include "include/pybind11/complex.h" +#include "include/pybind11/functional.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "include/pybind11/stl.h" +#include "tensorflow/core/data/service/server_lib.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_pywrap_server_lib, m) { + py::class_(m, "GrpcDataServer"); + + m.def( + "TF_DATA_NewMasterServer", + [](int port, std::string protocol) + -> std::unique_ptr { + std::unique_ptr server; + tensorflow::Status status = + tensorflow::data::NewMasterServer(port, protocol, &server); + tensorflow::MaybeRaiseFromStatus(status); + server->Start(); + return server; + }, + py::return_value_policy::reference); + + m.def( + "TF_DATA_NewWorkerServer", + [](int port, std::string protocol, std::string master_address) + -> std::unique_ptr { + std::unique_ptr server; + tensorflow::Status status = tensorflow::data::NewWorkerServer( + port, protocol, master_address, &server); + tensorflow::MaybeRaiseFromStatus(status); + server->Start(); + return server; + }, + py::return_value_policy::reference); + m.def( + "TF_DATA_ServerTarget", + [](tensorflow::data::GrpcDataServer* server) -> std::string { + return server->Target(); + }, + py::return_value_policy::copy); + m.def("TF_DATA_DeleteServer", + [](tensorflow::data::GrpcDataServer* server) { server->Stop(); }); +}; diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index cf2d921e976..f1c48bd8fed 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -61,13 +61,14 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def, /*thread_pool=*/nullptr, /*parent=*/nullptr, /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - }); + Rendezvous::Factory{ + [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + return Status::OK(); + }}); string fetch_node = ""; - for (auto node : graph_def.node()) { + for (const auto& node : graph_def.node()) { if (node.op() == "_Retval") { fetch_node = node.input(0); } diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 4c811831d4f..d921e9c2cf1 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -272,10 +272,9 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( return; } - MEMDEBUG_CACHE_STEPID(0); // Note that it would be nice to cache the step_id here, but it's not // available. - MEMDEBUG_CACHE_OP("SameWorkerRecvDone"); + auto op_annotation = ScopedMemoryDebugAnnotation("SameWorkerRecvDone", 0); AllocatorAttributes attr = recv_args.alloc_attrs; attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || recv_args.alloc_attrs.gpu_compatible()); @@ -324,8 +323,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: " << parsed.FullKey() << ")."; - MEMDEBUG_CACHE_OP("RecvAsync"); - MEMDEBUG_CACHE_STEPID(0); + auto op_annotation = ScopedMemoryDebugAnnotation("RecvAsync", 0); // Are src and dst in the same worker? if (IsSameWorker(parsed.src, parsed.dst)) { // Recv the tensor from local_. diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index cd6a3d53b7d..13e61e55ee0 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -315,7 +315,7 @@ TEST_F(DeviceResDistTest, Workers2Devices2) { ValidateCollectiveParams(num_workers, num_devices); } -#ifndef GOOGLE_CUDA +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM namespace { // A mock NcclReducer for testing group runtime details initialization with CPU // builds. The only meaningful function in this class is diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index c7d218258e8..5afd679dc9f 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -129,7 +129,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( } AllocatorAttributes cpu_attr; cpu_attr.set_gpu_compatible(true); - MEMDEBUG_CACHE_OP( + auto op_annotation = ScopedMemoryDebugAnnotation( "CollectiveRemoteAccessDistributed::RecvFromPeer" "::recv_buf_callback"); Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index dfa35086659..f2f63a8fab5 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -120,15 +120,11 @@ void EagerClusterFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { - FunctionLibraryRuntime::Options opts_copy = opts; - if (!opts_copy.op_id.has_value()) { - opts_copy.op_id = ctx_->RemoteMgr()->NextOpId(); - } - std::vector function_args; - for (const auto& tensor : args) { - function_args.push_back(tensor); - } - Run(opts_copy, handle, function_args, rets, std::move(done)); + std::vector function_args; + for (const auto& tensor : args) { + function_args.push_back(tensor); + } + Run(opts, handle, function_args, rets, std::move(done)); } void EagerClusterFunctionLibraryRuntime::Run( @@ -165,11 +161,6 @@ void EagerClusterFunctionLibraryRuntime::Run( EagerOperation* op = function_data->op.get(); - if (!opts.op_id.has_value()) { - done( - errors::Internal("op_id is not set for remote function: ", op->Name())); - } - eager::EnqueueRequest* request = new eager::EnqueueRequest; request->set_context_id(context_id_); eager::Operation* remote_op = request->add_queue()->mutable_operation(); @@ -187,7 +178,11 @@ void EagerClusterFunctionLibraryRuntime::Run( // The remote component function should use the same op_id as its parent // multi-device function's in order to get the global unique op_id generated // by the master context. - remote_op->set_id(opts.op_id.value()); + if (opts.op_id.has_value()) { + remote_op->set_id(opts.op_id.value()); + } else { + remote_op->set_id(kInvalidRemoteOpId); + } remote_op->set_is_function(true); remote_op->set_is_component_function(true); remote_op->set_func_step_id(opts.step_id); @@ -203,15 +198,39 @@ void EagerClusterFunctionLibraryRuntime::Run( // disabled, Run() returns when the remote function execution completes, which // might be blocked by a non-enqueued function execution. EnqueueResponse* response = new EnqueueResponse; - eager_client->EnqueueAsync(request, response, - [op, request, response, done](const Status& s) { - for (auto handle : op->Inputs()) { - handle->Unref(); - } - done(s); - delete request; - delete response; - }); + eager_client->EnqueueAsync( + request, response, + [op, request, response, rets, done = std::move(done)](const Status& s) { + Status status = s; + auto cleanup = gtl::MakeCleanup([request, response, &status, &done] { + done(status); + delete request; + delete response; + }); + + for (auto handle : op->Inputs()) { + handle->Unref(); + } + if (!status.ok()) { + return; + } + if (response->queue_response_size() != 1) { + status.Update(errors::Internal( + "Expect that the size of response queue equals 1, but got: ", + response->queue_response_size())); + return; + } + for (const auto& tensor_proto : response->queue_response(0).tensor()) { + Tensor t; + if (t.FromProto(tensor_proto)) { + rets->push_back(std::move(t)); + } else { + status.Update(errors::Internal("Could not convert tensor proto: ", + tensor_proto.DebugString())); + return; + } + } + }); } void EagerClusterFunctionLibraryRuntime::CleanUp( diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index ee2ea755bfa..cf28e2680d8 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -327,6 +327,13 @@ Status EagerServiceImpl::CreateMasterContext( return Status::OK(); } +Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { + const tensorflow::Tensor* t = nullptr; + TF_RETURN_IF_ERROR(handle->Tensor(&t)); + t->AsProtoTensorContent(proto); + return Status::OK(); +} + Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { const tensorflow::Tensor* t = nullptr; @@ -378,8 +385,8 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, return errors::InvalidArgument("Invalid TensorProto: ", input.tensor().DebugString()); } else { - TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( - std::move(tensor), nullptr, nullptr, eager_context, &handle)); + handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, + nullptr, eager_context); op->AddInput(handle); } } @@ -412,12 +419,21 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals)); - eager_context->RemoteMgr()->AddOperationOutputs( - absl::MakeSpan(retvals.data(), num_retvals), operation.id()); - - for (int i = 0; i < num_retvals; i++) { - TF_RETURN_IF_ERROR( - TensorHandleShape(retvals[i], queue_response->add_shape())); + if (operation.id() == kInvalidRemoteOpId) { + // Copy the output tensors back along with the response, since the op id + // is invalid which cannot be added to RemoteMgr. + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR( + TensorHandleProto(retvals[i], queue_response->add_tensor())); + retvals[i]->Unref(); + } + } else { + eager_context->RemoteMgr()->AddOperationOutputs( + absl::MakeSpan(retvals.data(), num_retvals), operation.id()); + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR( + TensorHandleShape(retvals[i], queue_response->add_shape())); + } } return Status::OK(); @@ -558,9 +574,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, return errors::InvalidArgument("Unable to parse tensor proto"); } - TensorHandle* tensor_handle = nullptr; - TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( - std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle)); + TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle( + std::move(tensor), nullptr, nullptr, eager_context); TensorHandle* copied_handle = nullptr; Device* device; TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 73bc42be0c5..2006d0a4d5c 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -560,13 +560,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { /*thread_pool=*/nullptr, eager_cluster_flr_.get()); } - void CheckOutputsAndClose(const int64 op_id) { - const tensorflow::Tensor* t = nullptr; - tensorflow::TensorHandle* tensor_handle; - TF_ASSERT_OK(eager_service_impl_.GetTensorHandle( - context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle)); - TF_ASSERT_OK(tensor_handle->Tensor(&t)); - auto actual = t->flat(); + void CheckOutputTensorAndClose(const Tensor& tensor) { + auto actual = tensor.flat(); EXPECT_EQ(4, actual.size()); EXPECT_EQ(7, actual(0)); EXPECT_EQ(10, actual(1)); @@ -581,6 +576,15 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { &close_context_response)); } + void CheckOutputsAndClose(const int64 op_id) { + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl_.GetTensorHandle( + context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + CheckOutputTensorAndClose(*t); + } + protected: const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0"; const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -649,8 +653,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { CheckOutputsAndClose(op_id); } -// Test executes a remote function with a local tensor input. -TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { +// Test executes a remote function with local input and output tensors. +TEST_F(FunctionWithRemoteInputsTest, + EagerClusterFLRTestWithLocalInputAndOutput) { Init(); // Instantiate MatMulFunction on remote_device. FunctionLibraryRuntime::Handle handle; @@ -681,11 +686,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle)); TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor)); - // Send input_tensor to the remote device and execute MatMulFunction on the - // remote device. + // Send input_tensor to the remote device, execute MatMulFunction on the + // remote device, and send the output back. FunctionLibraryRuntime::Options opts; - const uint64 op_id = 2; - opts.op_id = op_id; Notification execute_done; std::vector inputs = {*input_tensor}; std::vector outputs; @@ -696,7 +699,8 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { }); execute_done.WaitForNotification(); TF_ASSERT_OK(status); - CheckOutputsAndClose(op_id); + EXPECT_EQ(outputs.size(), 1); + CheckOutputTensorAndClose(outputs.at(0)); } // Test executes a remote function through KernelAndDeviceFunc. diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index 54fb10e721d..ef3d42de037 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -162,8 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, in.op_device().empty() ? in.device() : in.op_device(); TF_RETURN_IF_ERROR( parent_->FindDeviceFromName(device_name.c_str(), &device)); - TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle( - in.op_id(), in.output_num(), in.dtype(), device, parent_, out)); + *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(), + in.dtype(), device, parent_); TensorHandle::ResourceHandleInfo resource_handle_info; std::vector* dtypes_and_shapes = &resource_handle_info.dtypes_and_shapes; diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h index d075345a027..54c987d4daa 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h @@ -26,6 +26,8 @@ limitations under the License. namespace tensorflow { namespace eager { +const int64 kInvalidRemoteOpId = -1; + // This class manages the states required to setup an eager cluster. // TODO(fishx): Move remote state from context to this class. class RemoteMgr { diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc index 90213c978ed..eb2f2aea632 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc @@ -70,9 +70,8 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) { RemoteMgr remote_mgr(false, ctx_); Tensor t(DT_FLOAT, TensorShape({0})); - TensorHandle* handle; - TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_, - local_device_, ctx_, &handle)); + TensorHandle* handle = TensorHandle::CreateLocalHandle( + std::move(t), local_device_, local_device_, ctx_); const uint64 op_id = 2; const int output_num = 3; TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id, @@ -91,10 +90,9 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) { const uint64 op_id = 3; const int output_num = 1; - TensorHandle* handle; - TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle( + TensorHandle* handle = TensorHandle::CreateUnshapedRemoteHandle( op_id, output_num, - /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle)); + /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_); RemoteTensorHandle remote_handle; TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle( handle, &remote_handle, remote_device_, remote_device_->name())); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 8239dbcc72d..70704a27736 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -67,7 +67,7 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) } GraphMgr::~GraphMgr() { - for (auto p : table_) p.second->Unref(); + for (const auto& p : table_) p.second->Unref(); } GraphMgr::Item::~Item() { @@ -141,13 +141,18 @@ Status GraphMgr::InitItem( gdef.versions().producer(), item->lib_def.get(), graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr, /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, - [this, session](const int64 step_id, const DeviceMgr*, - Rendezvous** r) -> Status { - auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id); - TF_RETURN_IF_ERROR(remote_r->Initialize(session)); - *r = remote_r; - return Status::OK(); - })); + Rendezvous::Factory{ + [this, session](const int64 step_id, const DeviceMgr*, + Rendezvous** r) -> Status { + auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id); + TF_RETURN_IF_ERROR(remote_r->Initialize(session)); + *r = remote_r; + return Status::OK(); + }, + [this](const int64 step_id) { + this->worker_env_->rendezvous_mgr->Cleanup(step_id); + return Status::OK(); + }})); // Constructs the graph out of "gdef". Graph graph(OpRegistry::Global()); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index a12b392f83a..7ad05008a3b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -669,7 +669,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, AllocatorAttributes cpu_attr; cpu_attr.set_gpu_compatible(true); cpu_attr.set_nic_compatible(true); - MEMDEBUG_CACHE_OP("GrpcWorker::RecvBufAsync::consumer_callback"); + auto op_annotation = ScopedMemoryDebugAnnotation( + "GrpcWorker::RecvBufAsync::consumer_callback"); Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), hook->prod_value->dtype(), hook->prod_value->shape()); diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 6757a9b593e..51cc27426b1 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -27,10 +27,8 @@ limitations under the License. namespace tensorflow { -#ifdef TENSORFLOW_MEM_DEBUG thread_local const char* pending_op_name = nullptr; -thread_local uint64 pending_step_id = 0; -#endif +thread_local int64 pending_step_id = 0; string AllocatorStats::DebugString() const { return strings::Printf( diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 2e239a4d6de..46cb8a6cae1 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -62,29 +62,32 @@ struct AllocationAttributes { TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes); }; -// If defined, the runtime will cache Op names in thread-local memory -// and some allocators will try to tag allocations with the requesting Op. -#ifdef TENSORFLOW_MEM_DEBUG +// The runtime will cache Op names in thread-local memory and some allocators +// will try to tag allocations with the requesting Op. extern thread_local const char* pending_op_name; -extern thread_local uint64 pending_step_id; -#define MEMDEBUG_CACHE_OP(N) \ - do { \ - pending_op_name = (N); \ - } while (0) -#define MEMDEBUG_CACHE_STEPID(N) \ - do { \ - pending_step_id = (N); \ - } while (0) -#define MEMDEBUG_CACHE_VAL pending_op_name -#else -#define MEMDEBUG_CACHE_OP(N) \ - do { \ - } while (0) -#define MEMDEBUG_CACHE_STEPID(N) \ - do { \ - } while (0) -#define MEMDEBUG_CACHE_VAL nullptr -#endif +extern thread_local int64 pending_step_id; + +// Wrapper class of pending_op_name and pending_step_id for RAII. +class ScopedMemoryDebugAnnotation { + public: + explicit ScopedMemoryDebugAnnotation(const char* op_name) { + last_op_name_ = pending_op_name; + pending_op_name = op_name; + } + + explicit ScopedMemoryDebugAnnotation(const char* op_name, int64 step_id) { + last_op_name_ = pending_op_name; + pending_op_name = op_name; + pending_step_id = step_id; + } + + ~ScopedMemoryDebugAnnotation() { pending_op_name = last_op_name_; } + + private: + // Stores the previous value of pending_op_name in case the annotations are + // nested. + const char* last_op_name_ = nullptr; +}; // Runtime statistics collected by an allocator. Exactly the same as // stream_executor::AllocatorStats, but independently defined to preserve the @@ -114,7 +117,7 @@ struct AllocatorStats { bytes_reserved(0), peak_bytes_reserved(0) {} - string DebugString() const; + std::string DebugString() const; }; // Allocator is an abstract interface for allocating and deallocating @@ -127,7 +130,7 @@ class Allocator { virtual ~Allocator(); // Return a string identifying this allocator - virtual string Name() = 0; + virtual std::string Name() = 0; // Return an uninitialized block of memory that is "num_bytes" bytes // in size. The returned pointer is guaranteed to be aligned to a @@ -242,7 +245,7 @@ class AllocatorWrapper : public Allocator { // Returns the wrapped allocator to which all calls are delegated. Allocator* wrapped() const { return wrapped_; } - string Name() override { return wrapped_->Name(); } + std::string Name() override { return wrapped_->Name(); } void* AllocateRaw(size_t alignment, size_t num_bytes) override { return wrapped_->AllocateRaw(alignment, num_bytes); @@ -336,7 +339,7 @@ struct AllocatorAttributes { int32 scope_id = 0; // Returns a human readable representation of this. - string DebugString() const; + std::string DebugString() const; }; // Returns a trivial implementation of Allocator, which is a process singleton. diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 7903c5795e7..113adbdd432 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -822,6 +822,78 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { + string data_format_str; + if (!c->GetAttr("data_format", &data_format_str).ok()) { + data_format_str = "NHWC"; + } + TensorFormat data_format; + if (!FormatFromString(data_format_str, &data_format)) { + return errors::InvalidArgument("Invalid data format string: ", + data_format_str); + } + + // For the rest of this function, output_grad_* describes out_backprop and + // input_grad_* describes in_backprop. + ShapeHandle output_grad_shape = c->input(2); + TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape)); + ShapeHandle filter_shape = c->input(1); + TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape)); + + DimensionHandle batch_size_dim; + DimensionHandle output_grad_depth_dim; + gtl::InlinedVector output_grad_spatial_dims(2); + TF_RETURN_IF_ERROR(DimensionsFromShape( + output_grad_shape, data_format, &batch_size_dim, + absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c)); + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused)); + + ShapeHandle specified_input_grad_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape)); + if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) { + TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4, + &specified_input_grad_shape)); + } + + // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number + // of groups is larger than 1. If input_sizes is a 4D shape, we collect + // input_grad_depth_dim from input_sizes; otherwise we compute it as + // c->Dim(filter_shape,2). + DimensionHandle input_grad_depth_dim; + gtl::InlinedVector specified_input_grad_spatial_dims(2); + int specified_input_grad_rank = c->Rank(specified_input_grad_shape); + if (specified_input_grad_rank == 4) { + DimensionHandle specified_batch_size_dim; + TF_RETURN_IF_ERROR(DimensionsFromShape( + specified_input_grad_shape, data_format, &specified_batch_size_dim, + absl::MakeSpan(specified_input_grad_spatial_dims), + &input_grad_depth_dim, c)); + TF_RETURN_IF_ERROR( + c->Merge(specified_batch_size_dim, batch_size_dim, &unused)); + } else if (specified_input_grad_rank == 2) { + specified_input_grad_spatial_dims[0] = + c->Dim(specified_input_grad_shape, 0); + specified_input_grad_spatial_dims[1] = + c->Dim(specified_input_grad_shape, 1); + input_grad_depth_dim = c->Dim(filter_shape, 2); + } else { + return errors::InvalidArgument( + "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 " + "values, but got: ", + specified_input_grad_rank); + } + + ShapeHandle input_grad_shape; + TF_RETURN_IF_ERROR(ShapeFromDimensions( + batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim, + data_format, c, &input_grad_shape)); + c->set_output(0, input_grad_shape); + return Status::OK(); +} + namespace { Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 252e56309ca..e1984abab7e 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -138,6 +138,9 @@ Status DepthwiseConv2DNativeShapeWithExplicitPadding( // explicit padding. Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); +// Shape function for Conv2DBackpropInput. +Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); + // Shape function for AvgPool-like operations. Status AvgPoolShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 5393b162e80..d27ef1da61d 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -104,60 +104,51 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, return it != ndef.attr().end() && it->second.b(); }(); - // For functions (which have no KernelDef) and their gradients, we can only - // best-effort derive the memory type from the data type. For now, we assume - // int32 is always on host memory and other types are always on device memory. - // TODO(zhifengc,phawkins): We should do type inference over function bodies - // to derive the correct input/output memory types. We should also split - // host-memory and non host-memory arguments into separate type lists. - if (!status.ok() || IsFunctionCallOp(ndef.op())) { - if (device_type.type_string() == "TPU" || has_xla_compile) { - // Here we assume that if tf.function() is called within - // "with tf.device('/device:TPU:0')", the whole function will be compiled - // and executed on TPU. This is true today, but when we implement auto - // clustering on function body, this will no longer be true. For example, - // we might want to place string arguments on host. - for (const auto& t : inp_dtypes) - inp_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t)); - for (const auto& t : out_dtypes) - out_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t)); - } else { - for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); - for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); + bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op()); + auto host_memory_required = [&](const DataType& dt) { + bool int32_on_device = + has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile; + return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device); + }; + + if (has_kernel_def) { + // Gets the input/output names and their corresponding endpoint ranges. + NameRangeMap inp_names; + NameRangeMap out_names; + TF_RETURN_IF_ERROR( + NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); + out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); + + // Fills in host memory types based on the kernel def. + const auto& from_proto = kdef->host_memory_arg(); + std::vector host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); + MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(*op_def)); } - return Status::OK(); - } - - // Gets the input/output names and their corresponding endpoint ranges. - NameRangeMap inp_names; - NameRangeMap out_names; - TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); - - // Now that we know the size, fill with the default 'DEVICE_MEMORY'. - inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); - out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); - - // Fills in host memory types based on the kernel def. - const auto& from_proto = kdef->host_memory_arg(); - std::vector host_memory_args(from_proto.begin(), from_proto.end()); - MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); - MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); - if (!host_memory_args.empty()) { - return errors::InvalidArgument( - "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), - "' not found in OpDef: ", SummarizeOpDef(*op_def)); + } else { + // Set all the datatype to DEVICE_MEMORY by default, later on change it to + // HOST_MEMORY where it is required by the datatype. + inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY); + out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY); } CHECK_LE(inp_mtypes->size(), inp_dtypes.size()); CHECK_LE(out_mtypes->size(), out_dtypes.size()); // Mark e.g. all resource and string types as host memory. for (int i = 0; i < inp_mtypes->size(); ++i) { - if (DataTypeAlwaysOnHost(inp_dtypes[i])) { + if (host_memory_required(inp_dtypes[i])) { (*inp_mtypes)[i] = HOST_MEMORY; } } for (int i = 0; i < out_mtypes->size(); ++i) { - if (DataTypeAlwaysOnHost(out_dtypes[i])) { + if (host_memory_required(out_dtypes[i])) { (*out_mtypes)[i] = HOST_MEMORY; } } diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 40e075ba737..c365716f209 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -387,45 +387,28 @@ void OpKernelContext::SetStatus(const Status& status) { } Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - name, - "' when single-valued input was " - "expected"); - } - if (input_is_ref(start)) { + int index; + TF_RETURN_IF_ERROR(get_input_index(name, &index)); + if (input_is_ref(index)) { return errors::InvalidArgument("OpKernel used ref input name '", name, "' when non-ref input was expected"); } - *tensor = (*params_->inputs)[start].tensor; + *tensor = (*params_->inputs)[index].tensor; return Status::OK(); } Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - name, - "' when single-valued input was " - "expected"); - } - const TensorValue& value((*params_->inputs)[start]); + int index; + TF_RETURN_IF_ERROR(get_input_index(name, &index)); + const TensorValue& value((*params_->inputs)[index]); *dtype = value.dtype(); return Status::OK(); } Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - name, - "' when single-valued input was expected"); - } - *out_mutex = input_ref_mutex(start); + int index; + TF_RETURN_IF_ERROR(get_input_index(name, &index)); + *out_mutex = input_ref_mutex(index); return Status::OK(); } @@ -497,23 +480,9 @@ bool OpKernelContext::forward_input_to_output_with_shape( Status OpKernelContext::forward_input_to_output_with_shape( StringPiece input_name, StringPiece output_name, const TensorShape& output_shape, Tensor** output) { - int input_index, output_index, stop; - TF_RETURN_IF_ERROR( - params_->op_kernel->InputRange(input_name, &input_index, &stop)); - if (stop != input_index + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - input_name, - "' when single-valued input was " - "expected"); - } - TF_RETURN_IF_ERROR( - params_->op_kernel->OutputRange(output_name, &output_index, &stop)); - if (stop != output_index + 1) { - return errors::InvalidArgument("OpKernel used list-valued output name '", - output_name, - "' when single-valued output was " - "expected"); - } + int input_index, output_index; + TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index)); + TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index)); if (!forward_input_to_output_with_shape(input_index, output_index, output_shape, output)) { return errors::FailedPrecondition("OpKernel could not forward input '", @@ -621,23 +590,18 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) { Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, bool lock_held) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - name, - "' when single-valued input was expected"); - } - if (!input_is_ref(start)) { + int index; + TF_RETURN_IF_ERROR(get_input_index(name, &index)); + if (!input_is_ref(index)) { return errors::InvalidArgument("OpKernel used non-ref input name '", name, "' when ref input was expected"); } // return a copy of the Ref acquired while holding the mutex if (lock_held) { - *tensor = *(*params_->inputs)[start].tensor; + *tensor = *(*params_->inputs)[index].tensor; } else { - tf_shared_lock l(*input_ref_mutex(start)); - *tensor = *(*params_->inputs)[start].tensor; + tf_shared_lock l(*input_ref_mutex(index)); + *tensor = *(*params_->inputs)[index].tensor; } return Status::OK(); } @@ -645,18 +609,13 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, Status OpKernelContext::replace_ref_input(StringPiece name, const Tensor& tensor, bool lock_held) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - name, - "' when single-valued input was expected"); - } - if (!input_is_ref(start)) { + int index; + TF_RETURN_IF_ERROR(get_input_index(name, &index)); + if (!input_is_ref(index)) { return errors::InvalidArgument("OpKernel used immutable input name '", name, "' when ref input was expected"); } - replace_ref_input(start, tensor, lock_held); + replace_ref_input(index, tensor, lock_held); return Status::OK(); } @@ -744,8 +703,8 @@ Status OpKernelContext::allocate_tensor( DataType type, const TensorShape& shape, Tensor* out_tensor, AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { Allocator* a = get_allocator(attr); - MEMDEBUG_CACHE_OP(op_kernel().name().c_str()); - MEMDEBUG_CACHE_STEPID(step_id()); + auto op_annotation = + ScopedMemoryDebugAnnotation(op_kernel().name_view().data(), step_id()); Tensor new_tensor(a, type, shape, AllocationAttributes(allocation_attr.no_retry_on_failure, /* allocation_will_be_logged= */ true, @@ -888,7 +847,22 @@ Status OpKernelContext::allocate_persistent(DataType type, return s; } -Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { +Status OpKernelContext::get_input_index(StringPiece name, + int* out_index) const { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + *out_index = start; + return Status::OK(); +} + +Status OpKernelContext::get_output_index(StringPiece name, + int* out_index) const { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -897,22 +871,31 @@ Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { "' when single-valued output was " "expected"); } - set_output(start, tensor); + *out_index = start; return Status::OK(); } -void OpKernelContext::set_output(int index, const Tensor& tensor) { - CHECK_GE(index, 0); - CHECK_LT(index, outputs_.size()); - const DataType type = params_->op_kernel->output_type(index); - CHECK(!IsRefType(type)); - CHECK_EQ(mutable_output(index), nullptr); +Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { + int index; + TF_RETURN_IF_ERROR(get_output_index(name, &index)); + set_output(index, tensor); + return Status::OK(); +} +Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) { + int index; + TF_RETURN_IF_ERROR(get_output_index(name, &index)); + set_output(index, std::move(tensor)); + return Status::OK(); +} + +bool OpKernelContext::maybe_set_output_by_allocate_and_copy( + int index, const Tensor& tensor) { bool allocate_and_copy = false; const bool never_forward = (params_->forward_from_array != nullptr && params_->forward_from_array[index] == Params::kNeverForward); - if (never_forward) { + if (TF_PREDICT_FALSE(never_forward)) { maybe_initialize_scope_id_set(); if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) == allocated_scope_ids_->end()) { @@ -929,7 +912,7 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) { } } - if (allocate_and_copy) { + if (TF_PREDICT_FALSE(allocate_and_copy)) { // This output was marked to not be forwarded either during graph // construction or grappler passes. Force an allocation and copy input to // output. @@ -939,31 +922,59 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) { << params_->forward_from_array[index] << " alloc_attr.scope_id " << output_alloc_attr(index).scope_id; auto new_tensor = MakeUnique(); - Status s = allocate_tensor(type, tensor.shape(), new_tensor.get(), + Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(), output_alloc_attr(index)); TF_CHECK_OK(s); device()->CopyTensorInSameDevice(&tensor, new_tensor.get(), op_device_context(), [](const Status&) {}); outputs_[index] = TensorValue(new_tensor.release()); - } else { + } + return allocate_and_copy; +} + +void OpKernelContext::maybe_track_allocations_for_set_output( + const Tensor& tensor) { + if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) { + DCHECK(tracking_state_); + mutex_lock l(tracking_state_->stats_mu); + const auto it = std::find_if( + tracking_state_->temp_tensor_buffer_and_size.begin(), + tracking_state_->temp_tensor_buffer_and_size.end(), + [&tensor](const std::pair& e) { + return e.first == + static_cast(tensor.tensor_data().data()); + }); + if (it != tracking_state_->temp_tensor_buffer_and_size.end()) { + tracking_state_->temp_memory_allocated -= it->second; + tracking_state_->temp_tensor_buffer_and_size.erase(it); + } + } +} + +void OpKernelContext::set_output(int index, const Tensor& tensor) { + CHECK_GE(index, 0); + CHECK_LT(index, outputs_.size()); + const DataType type = params_->op_kernel->output_type(index); + CHECK(!IsRefType(type)); + CHECK_EQ(outputs_[index].tensor, nullptr); + if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) { // Input can be forwarded to output; incref on `tensor` and set output at // `index` to this tensor. outputs_[index] = TensorValue(new Tensor(tensor)); - if (track_allocations() && tensor.TotalBytes() > 0) { - DCHECK(tracking_state_); - mutex_lock l(tracking_state_->stats_mu); - const auto it = std::find_if( - tracking_state_->temp_tensor_buffer_and_size.begin(), - tracking_state_->temp_tensor_buffer_and_size.end(), - [&tensor](const std::pair& e) { - return e.first == - static_cast(tensor.tensor_data().data()); - }); - if (it != tracking_state_->temp_tensor_buffer_and_size.end()) { - tracking_state_->temp_memory_allocated -= it->second; - tracking_state_->temp_tensor_buffer_and_size.erase(it); - } - } + maybe_track_allocations_for_set_output(*outputs_[index].tensor); + } +} + +void OpKernelContext::set_output(int index, Tensor&& tensor) { + CHECK_GE(index, 0); + CHECK_LT(index, outputs_.size()); + const DataType type = params_->op_kernel->output_type(index); + CHECK(!IsRefType(type)); + CHECK_EQ(outputs_[index].tensor, nullptr); + if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) { + // Input can be forwarded to output; set output at `index` to this tensor. + outputs_[index] = TensorValue(new Tensor(std::move(tensor))); + maybe_track_allocations_for_set_output(*outputs_[index].tensor); } } @@ -977,28 +988,16 @@ void OpKernelContext::set_output_ref(int index, mutex* mu, Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued output name '", - name, - "' when single-valued output was " - "expected"); - } - set_output_ref(start, mu, tensor_for_ref); + int index; + TF_RETURN_IF_ERROR(get_output_index(name, &index)); + set_output_ref(index, mu, tensor_for_ref); return Status::OK(); } Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { - int start, stop; - TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); - if (stop != start + 1) { - return errors::InvalidArgument("OpKernel used list-valued output name '", - name, - "' when single-valued output was " - "expected"); - } - *tensor = mutable_output(start); + int index; + TF_RETURN_IF_ERROR(get_output_index(name, &index)); + *tensor = mutable_output(index); return Status::OK(); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 1644eff9319..2f140316b3a 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -492,6 +492,7 @@ class OpOutputList { DataType expected_output_dtype(int i) const; Status allocate(int i, const TensorShape& shape, Tensor** output); void set(int i, const Tensor& tensor); + void set(int i, Tensor&& tensor); void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); int size() const { return stop_ - start_; } Iterator begin() const { return Iterator(this, 0); } @@ -1031,6 +1032,9 @@ class OpKernelContext { // REQUIRES: 'tensor' must have the same MemoryType as // output_memory_types[index]. See comment above. Status set_output(StringPiece name, const Tensor& tensor); + Status set_output(StringPiece name, Tensor&& tensor); + void set_output(int index, const Tensor& tensor); + void set_output(int index, Tensor&& tensor); // To output a reference. Caller retains ownership of mu and tensor_for_ref, // and they must outlive all uses within the step. See comment above. @@ -1198,7 +1202,6 @@ class OpKernelContext { // The following functions all have versions that return Status // to capture error conditions, and are strongly preferred. Tensor* mutable_output(int index); - void set_output(int index, const Tensor& tensor); mutex* input_ref_mutex(int index); void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); TensorValue release_output(int index); @@ -1274,6 +1277,16 @@ class OpKernelContext { Tensor* out_tensor, AllocatorAttributes allocator_attr, const AllocationAttributes& allocation_attr); + // Helpers for `set_output()`. + + // Returns `true` if the tensor was copied into an allocated output. + bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor); + + void maybe_track_allocations_for_set_output(const Tensor& tensor); + + Status get_input_index(StringPiece name, int* out_index) const; + Status get_output_index(StringPiece name, int* out_index) const; + // Initialize the allocated_scope_ids_ set the first time this method is // called. void maybe_initialize_scope_id_set(); @@ -1704,6 +1717,12 @@ inline void OpOutputList::set(int i, const Tensor& tensor) { ctx_->set_output(start_ + i, tensor); } +inline void OpOutputList::set(int i, Tensor&& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, std::move(tensor)); +} + inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { DCHECK_GE(i, 0); DCHECK_LT(i, stop_ - start_); diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index e5283cdd13c..ccd6d102b5e 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -129,8 +129,43 @@ class RendezvousInterface { // threads with no clear owner. class Rendezvous : public RendezvousInterface, public core::RefCounted { public: - using Factory = - std::function; + class Factory { + public: + // Default to a factory that evaluates to false. + Factory() : valid_(false) {} + + Factory(std::function + create_fn, + std::function cleanup_fn) + : valid_(true), + create_fn_(std::move(create_fn)), + cleanup_fn_(std::move(cleanup_fn)) {} + + // If no clean up fn is provided, just put in a dummy. + // For backwards compatibility. + explicit Factory( + std::function + create_fn) + : valid_(true), + create_fn_(std::move(create_fn)), + cleanup_fn_([](const int64 step_id) { return Status::OK(); }) {} + + explicit operator bool() const { return valid_; } + + Status operator()(const int64 step_id, const DeviceMgr* device_mgr, + Rendezvous** rendez) const { + return create_fn_(step_id, device_mgr, rendez); + } + + Status CleanUp(const int64 step_id) const { return cleanup_fn_(step_id); } + + private: + bool valid_; + std::function + create_fn_; + std::function cleanup_fn_; + }; + // Constructs a rendezvous key for the tensor of "name" sent from // "src_device" to "dst_device". The tensor is generated in the frame // and iteration specified by "frame_iter". diff --git a/tensorflow/core/framework/resource_handle.h b/tensorflow/core/framework/resource_handle.h index ac6aef1b19c..88c9f9da190 100644 --- a/tensorflow/core/framework/resource_handle.h +++ b/tensorflow/core/framework/resource_handle.h @@ -38,23 +38,23 @@ class ResourceHandle { ~ResourceHandle(); // Unique name for the device containing the resource. - const string& device() const { return device_; } + const std::string& device() const { return device_; } // Names of the devices containing the resource. const std::vector& allowed_devices() const { return allowed_devices_; } - void set_device(const string& device) { device_ = device; } + void set_device(const std::string& device) { device_ = device; } void set_allowed_devices(const std::vector& devices) { allowed_devices_ = devices; } // Container in which this resource is placed. - const string& container() const { return container_; } - void set_container(const string& container) { container_ = container; } + const std::string& container() const { return container_; } + void set_container(const std::string& container) { container_ = container; } // Unique name of this resource. - const string& name() const { return name_; } - void set_name(const string& name) { name_ = name; } + const std::string& name() const { return name_; } + void set_name(const std::string& name) { name_ = name; } // Hash code for the type of the resource. Is only valid in the same device // and in the same execution. @@ -63,8 +63,10 @@ class ResourceHandle { // For debug-only, the name of the type pointed to by this handle, if // available. - const string& maybe_type_name() const { return maybe_type_name_; } - void set_maybe_type_name(const string& value) { maybe_type_name_ = value; } + const std::string& maybe_type_name() const { return maybe_type_name_; } + void set_maybe_type_name(const std::string& value) { + maybe_type_name_ = value; + } // Data types and shapes for the underlying resource. std::vector dtypes_and_shapes() const { @@ -80,10 +82,10 @@ class ResourceHandle { void FromProto(const ResourceHandleProto& proto); // Serialization via ResourceHandleProto - string SerializeAsString() const; - bool ParseFromString(const string& s); + std::string SerializeAsString() const; + bool ParseFromString(const std::string& s); - string DebugString() const; + std::string DebugString() const; // GUID for anonymous resources. Resources with this shared_name will have // their shared_name replaced with a GUID at creation time @@ -93,19 +95,19 @@ class ResourceHandle { public: // The default device containing the resource, where the ResourceHandle is // initially created. - string device_; + std::string device_; // A set of devices containing the resource. If empty, the resource only // exists on device_. Can be represented in wildcard patterns. std::vector allowed_devices_; - string container_; - string name_; + std::string container_; + std::string name_; uint64 hash_code_ = 0; - string maybe_type_name_; + std::string maybe_type_name_; std::vector dtypes_and_shapes_; }; // For backwards compatibility for when this was a proto -string ProtoDebugString(const ResourceHandle& handle); +std::string ProtoDebugString(const ResourceHandle& handle); // Encodes a list of ResourceHandle protos in the given StringListEncoder. void EncodeResourceHandleList(const ResourceHandle* p, int64 n, diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 095a672b044..40a6d53d223 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -86,11 +86,11 @@ class ShapeInferenceTestutil { .error_message()) #define INFER_ERROR(error_substring, op, i) \ { \ - string error_message = \ + std::string error_message = \ ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ op, i, "e") \ .error_message(); \ - const string& substring = error_substring; \ + const std::string& substring = error_substring; \ EXPECT_NE("", error_message); \ EXPECT_TRUE(absl::StrContains(error_message, substring)) \ << "Expected to see '" << substring << "' in '" << error_message \ diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 11910766ba8..54541be0b4f 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -571,19 +571,19 @@ class Tensor { int64 begin) const; /// Render the first `max_entries` values in `*this` into a string. - string SummarizeValue(int64 max_entries, bool print_v2 = false) const; + std::string SummarizeValue(int64 max_entries, bool print_v2 = false) const; /// A human-readable summary of the tensor suitable for debugging. // `num_values` is the number of actual data values in the tensor // included in the message. If the tensor might be resident in // GPU/TPU memory use DeviceSafeDebugString instead. - string DebugString(int num_values) const; - string DebugString() const { return DebugString(3); } + std::string DebugString(int num_values) const; + std::string DebugString() const { return DebugString(3); } // Variant of DebugString() that should be used for possibly non-CPU tensors. // If the tensor is not resident on CPU, we can't read its values as // DebugString() does. - string DeviceSafeDebugString() const; + std::string DeviceSafeDebugString() const; /// Fill in the `TensorDescription` proto with metadata about the /// tensor that is useful for monitoring and debugging. diff --git a/tensorflow/core/framework/tensor_interface.h b/tensorflow/core/framework/tensor_interface.h index f5d7bf53370..93883308662 100644 --- a/tensorflow/core/framework/tensor_interface.h +++ b/tensorflow/core/framework/tensor_interface.h @@ -17,8 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_ #include "tensorflow/c/tf_datatype.h" -#include "tensorflow/c/tf_status.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { // Abstract interface to a Tensor. // @@ -49,12 +52,10 @@ class AbstractTensorInterface { virtual bool CanMove() const = 0; }; -namespace tensorflow { - class TensorInterface : public AbstractTensorInterface { public: TensorInterface() {} - explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {} + explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {} ~TensorInterface() override {} TF_DataType Type() const override; @@ -66,14 +67,23 @@ class TensorInterface : public AbstractTensorInterface { bool IsAligned() const override; bool CanMove() const override; - Status ToTensor(Tensor* dst) const; + Status ToTensor(tensorflow::Tensor* dst) const; Status BitcastFrom(const TensorInterface& from, TF_DataType type, const int64_t* new_dims, int num_new_dims); + // TODO(gjn): This is not a very generic interface, but is needed for specific + // use cases. + tensorflow::Tensor& Tensor() { return tensor_; } + private: - Tensor tensor_; + tensorflow::Tensor tensor_; }; +inline Tensor& TensorFromInterface( + const std::unique_ptr& tensor) { + return down_cast(tensor.get())->Tensor(); +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_ diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 4a53a253586..ac1bef12370 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -69,8 +69,8 @@ class TensorShapeRep { int64 num_elements() const { return num_elements_; } /// For error messages. - string DebugString() const; - static string DebugString(const TensorShapeProto& proto); + std::string DebugString() const; + static std::string DebugString(const TensorShapeProto& proto); void DumpRep() const; // XXX @@ -397,7 +397,8 @@ class TensorShapeUtils { static Status MakeShape(gtl::ArraySlice shape, PartialTensorShape* out); - static string ShapeListString(const gtl::ArraySlice& shapes); + static std::string ShapeListString( + const gtl::ArraySlice& shapes); /// \brief Returns true iff `shape` starts with `prefix`. static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); @@ -462,7 +463,7 @@ class PartialTensorShape : public TensorShapeBase { /// common predicates on a partially known tensor shape. class PartialTensorShapeUtils { public: - static string PartialShapeListString( + static std::string PartialShapeListString( const gtl::ArraySlice& shapes); static bool AreIdentical(const gtl::ArraySlice& shapes0, diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index e09ea268cd7..61575a7b735 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -59,14 +59,14 @@ class DeviceType { explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} const char* type() const { return type_.c_str(); } - const string& type_string() const { return type_; } + const std::string& type_string() const { return type_; } bool operator<(const DeviceType& other) const; bool operator==(const DeviceType& other) const; bool operator!=(const DeviceType& other) const { return !(*this == other); } private: - string type_; + std::string type_; }; std::ostream& operator<<(std::ostream& os, const DeviceType& d); @@ -110,10 +110,10 @@ typedef gtl::InlinedVector, 4> PrioritizedDeviceTypeVector; // Convert the enums to strings for errors: -string DataTypeString(DataType dtype); -string DeviceTypeString(const DeviceType& device_type); -string DataTypeSliceString(const DataTypeSlice dtypes); -inline string DataTypeVectorString(const DataTypeVector& dtypes) { +std::string DataTypeString(DataType dtype); +std::string DeviceTypeString(const DeviceType& device_type); +std::string DataTypeSliceString(const DataTypeSlice dtypes); +inline std::string DataTypeVectorString(const DataTypeVector& dtypes) { return DataTypeSliceString(dtypes); } diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h index f0f53c91f47..011b7c11cf2 100644 --- a/tensorflow/core/graph/default_device.h +++ b/tensorflow/core/graph/default_device.h @@ -26,7 +26,7 @@ namespace graph { // Sets the default device for all nodes in graph_def to "device", // only if not already set. -inline void SetDefaultDevice(const string& device, GraphDef* graph_def) { +inline void SetDefaultDevice(const std::string& device, GraphDef* graph_def) { for (int i = 0; i < graph_def->node_size(); ++i) { auto node = graph_def->mutable_node(i); if (node->device().empty()) { diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 235d944bd60..cdb2d123eaf 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -41,6 +41,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -678,6 +679,10 @@ class Graph { // Builds a node name to node pointer index for all nodes in the graph. std::unordered_map BuildNodeNameIndex() const; + absl::optional>& GetConstArgIndicesCache() const { + return const_arg_indices_cache_; + } + // TODO(josh11b): uint64 hash() const; private: @@ -751,6 +756,10 @@ class Graph { // AddWhileContext() or Node::while_ctx(), but this manages the lifetime. std::map while_ctxs_; + // Cache of the indices of the arguments which need to be constant for the XLA + // compilation. + mutable absl::optional> const_arg_indices_cache_; + TF_DISALLOW_COPY_AND_ASSIGN(Graph); }; diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6bb1772e02d..bc3c0309ab9 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -821,56 +821,10 @@ Status GraphConstructor::ValidateShape(Node* node) { } s = refiner_->SetShape(node, i, h); if (!s.ok()) { - // If the output shape is incompatible with what is inferred - // by the graph for a very specific whitelist of ops, then we - // ignore this output shape. This can happen if there is a - // bug in the shape function for some operation, and the - // serialized graph def has the incorrect shape set when - // running on a newer binary with the fixed shape function. - // This is an escape hatch that allows us to correct shape - // functions that are not critical to correct execution but - // would cause graphs to fail if imported after correcting. - const string& op = node->type_string(); - const std::vector whitelist = { - // To be removed after 2017/03/08. - "RandomShuffleQueue", - "PaddingFIFOQueue", - "FIFOQueue", - "PriorityQueue", - "QueueSize", - "Stack", - "Barrier", - "BarrierReadySize", - "BarrierIncompleteSize", - "HashTable", - "MutableHashTable", - "MutableHashTableOfTensors", - "Mutex", - "CuckooTable", - "IndexTable", - "WholeFileReader", - "TextLineReader", - "FixedLengthRecordReader", - "TFRecordReader", - "IdentityReader", - "RefSwitch", - "RefEnter", - "RefNextIteration", - "RefMerge", - "RefIdentity", - "LMDBReader", - // To be removed after 2017/04/24. - "ConditionalAccumulator", - "SparseConditionalAccumulator", - "Table", - }; - if (std::find(whitelist.begin(), whitelist.end(), op) == - whitelist.end()) { - return errors::InvalidArgument( - "Node '", node->name(), "' has an ", kAttrName, - " attribute inconsistent with the GraphDef for output #", i, ": ", - s.error_message()); - } + return errors::InvalidArgument( + "Node '", node->name(), "' has an ", kAttrName, + " attribute inconsistent with the GraphDef for output #", i, ": ", + s.error_message()); } } node->ClearAttr(kAttrName); diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 8475032665e..89cb7d82b2b 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1172,31 +1172,6 @@ node { << s; } -TEST_F(GraphConstructorTest, ImportGraphDef_ShapeWhitelist) { - // Barrier's shape is an output vector of 2, but the graph says it's a vector - // of 1. This is currently whitelisted. - GraphDef def; - bool parsed = protobuf::TextFormat::ParseFromString( - R"EOF( - node { - name: "A" - op: "Barrier" - attr { - key: "_output_shapes" - value { list { shape {} } } - } - attr { - key: "component_types" - value { list { type: DT_FLOAT } } - } - } - )EOF", - &def); - ASSERT_TRUE(parsed); - Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); - EXPECT_EQ(Status::OK(), s) << s; -} - TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index e3688cc0a6d..1f9bae4f852 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -48,6 +48,11 @@ using shape_inference::ShapeAndType; using shape_inference::ShapeHandle; using TensorVector = gtl::InlinedVector; +// A large value for UnknownDim from Const used as a dim value in shape. +// Some ops treat "-1" specially, different from UnknownDim: +// e.g., shape input to Reshape op. +const int64 kUnknownDimFromConst = INT64_MAX; + template struct HashHandle { std::size_t operator()(const Handle& h) const { return h.Handle(); } @@ -353,15 +358,34 @@ void VerboseLogUnknownDimensionSources( } } -bool IsShapeFullyDefinedIntegerVectorOrScalar( - InferenceContext* ic, const ShapeHandle& shape, - const ShapeHandle& tensor_as_shape, const DataType& dtype) { - if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 || - !ic->FullyDefined(tensor_as_shape) || - (dtype != DT_INT32 && dtype != DT_INT64)) { - return false; +// Helper function to convert kUnknownDimFromConst into UnknownDim. +std::vector ReplaceUnknownDimFromConstWithUnknownDim( + InferenceContext* ic, const std::vector& shapes) { + std::vector converted_shapes(shapes.size()); + for (int i = 0; i < shapes.size(); i++) { + const auto& shape = shapes[i]; + if (!ic->RankKnown(shape)) { + converted_shapes[i] = shape; + continue; + } + bool just_copy = true; + std::vector dims; + for (int32 i = 0; i < ic->Rank(shape); ++i) { + DimensionHandle dim = ic->Dim(shape, i); + if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) { + just_copy = false; + dims.push_back(ic->UnknownDim()); + } else { + dims.push_back(dim); + } + } + if (just_copy) { + converted_shapes[i] = shape; + continue; + } + converted_shapes[i] = ic->MakeShape(dims); } - return true; + return converted_shapes; } // Returned tensor's shape is like `shape`, and its values and dtype are from @@ -414,6 +438,32 @@ NodeDef MakeConstNodeDefFromShape(InferenceContext* ic, } // namespace +// Note that tensor_as_shape input should not include kUnknownDimFromConst. +// This function check kUnknownDimFromConst, but will log WARNING. +// If checking input_tensors_as_shape_to_propgate or output_tensors_as_shape, +// which may include kUnknownDimFromConst, run +// convert it using ReplaceUnknownDimFromConstWithUnknownDim() before. +bool IsShapeFullyDefinedIntegerVectorOrScalar( + InferenceContext* ic, const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, const DataType& dtype) { + if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 || + !ic->FullyDefined(tensor_as_shape) || + (dtype != DT_INT32 && dtype != DT_INT64)) { + return false; + } + // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst. + for (int32 i = 0; i < ic->Rank(tensor_as_shape); ++i) { + DimensionHandle dim = ic->Dim(tensor_as_shape, i); + if (ic->Value(dim) == kUnknownDimFromConst) { + LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): " + << "tensor_as_shape input includes kUnknownDimFromConst -- " + << ic->DebugString(tensor_as_shape); + return false; + } + } + return true; +} + // Queue of nodes to process. Nodes can be enqueued in any order, but will be // dequeued in (roughly) topological order. Propagating shapes following a // topological ordering isn't required for correctness but helps speed things up @@ -569,10 +619,91 @@ class SymbolicShapeRefiner { // Additional info for propagating tensor values and tensor shapes. std::vector input_tensor_protos; std::vector output_tensor_protos; + // This is the same to inference_context->input_tensors_as_shapes, except + // that some UnknownDims (-1) can be kUnknownDimFromConst. + std::vector input_tensors_as_shapes_to_propagate; std::vector output_tensors_as_shapes; // Output shapes incompatible between annotation and shape inference. bool shape_incompatible = false; + + // Similar to DebugString() in InferenceContext, but prints out + // kUnknownDimFromConst properly. + std::string StringifyShapeHandle(ShapeHandle s) { + auto* ic = inference_context.get(); + if (ic->RankKnown(s)) { + std::vector vals; + for (int i = 0; i < ic->Rank(s); i++) { + DimensionHandle d = ic->Dim(s, i); + if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) { + vals.push_back("?(Const)"); + } else { + vals.push_back(ic->DebugString(d)); + } + } + return strings::StrCat("[", absl::StrJoin(vals, ","), "]"); + } else { + return "?"; + } + } + + std::string DebugString(const NodeDef& node) { + std::string output; + auto* ic = inference_context.get(); + absl::StrAppend( + &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(), + (ic->num_inputs() > 1 ? " inputs and " : " input and "), + ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output")); + if (op_data->is_function_op) { + absl::StrAppend(&output, " (function op)"); + } + absl::StrAppend(&output, ": \n"); + + for (int i = 0; i < ic->num_inputs(); i++) { + absl::StrAppend(&output, " input [", i, "] ", node.input(i), + " -- type: ", DataTypeString(input_types.at(i)), + ", shape: ", ic->DebugString(ic->input(i)), + ", tensor: "); + Tensor t1; + if (input_tensor_protos.size() > i && + input_tensor_protos.at(i) != nullptr && + t1.FromProto(*input_tensor_protos.at(i))) { + absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: "); + } else { + absl::StrAppend(&output, " null, tensor_as_shape: "); + } + if (input_tensors_as_shapes_to_propagate.size() > i) { + absl::StrAppend( + &output, + StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)), + "\n"); + } else { + absl::StrAppend(&output, " null\n"); + } + } + for (int i = 0; i < ic->num_outputs(); i++) { + absl::StrAppend(&output, " output [", i, + "] -- type: ", DataTypeString(output_types.at(i)), + ", shape: ", ic->DebugString(ic->output(i)), + ", tensor: "); + Tensor t2; + if (output_tensor_protos.size() > i && + output_tensor_protos.at(i) != nullptr && + t2.FromProto(*output_tensor_protos.at(i))) { + absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: "); + } else { + absl::StrAppend(&output, " null, tensor_as_shape: "); + } + if (output_tensors_as_shapes.size() > i) { + absl::StrAppend(&output, + StringifyShapeHandle(output_tensors_as_shapes.at(i)), + "\n"); + } else { + absl::StrAppend(&output, " null\n"); + } + } + return output; + } }; NodeContext* GetNodeContext(const NodeDef* node) { @@ -770,7 +901,7 @@ class SymbolicShapeRefiner { // Check if the shapes of the nodes in the fan-in of this node have changed, // and if they have, update the node input shapes. InferenceContext* ic = ctx->inference_context.get(); - std::vector input_tensors_as_shapes(ic->num_inputs()); + ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs()); ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr); for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) { @@ -797,7 +928,7 @@ class SymbolicShapeRefiner { // output_tensor_protos to input_tensor_protos and input_tensors, and // output_tensors_as_shapes to input_tensors_as_shapes. if (src_ctx->output_tensors_as_shapes.size() > src_output) { - input_tensors_as_shapes[dst_input] = + ctx->input_tensors_as_shapes_to_propagate[dst_input] = src_ctx->output_tensors_as_shapes[src_output]; } @@ -805,41 +936,6 @@ class SymbolicShapeRefiner { const auto* tensor_proto = src_ctx->output_tensor_protos[src_output]; if (tensor_proto != nullptr) { ctx->input_tensor_protos[dst_input] = tensor_proto; - - if (!ic->FullyDefined(input_tensors_as_shapes[dst_input])) { - // Tensorflow uses '-1' to encode unknown shape or dimension: - // - // -1 : unknown shape - // [-1] : vector of unknown size - // [-1, -1] : matrix of unknown size - // - // For example `tf.reshape(x, [-1])` will reshape an arbitrary - // tensor x to a vector. - // - // It's possible that the same Const with -1 is used in many - // places, but that doesn't mean the resultant shapes are - // identical. e.g., x1 = Reshape(x, c) and y1 = Reshape(y, c), - // where c is [-1]. In this case, shape inference yields both x1 and - // y1 as rank 1, size unknown, but still the shapes of x1 and y1 - // can be different. (even if we use different Const([-1]) for x1 - // and x2, graph optimizer may merge them to single Const through - // duplicate removal.) - // If we reuse output_tensors_as_shapes to input_tensors_as_shapes - // by copying ShapeHandle, they share the same Shape object, and - // SymbolicShapeManager, later in InferStatically(), assigns the - // same symbolic dim value (unique value < -1); in the above - // Reshape example, the shapes of x1 and y1 become, for example, - // [-278] and graph optimizer may yield incorrect output 'cause it - // assumes x1 and y1 have the same shape. - // To prevent this, we re-create a ShapeHandle from the Const - // tensor, instead of reusing output_tensors_as_shapes (so that - // ShapeHandles of the const fanouts have the same values, - // but different Shape objects -- SymbolicShapeManager assigns - // different symbol id to each fanout shape). - // TODO(dyoon): clean up the way values are propagated. - MaybeTensorProtoToShape(ic, *tensor_proto, - &input_tensors_as_shapes[dst_input]); - } } } @@ -879,10 +975,12 @@ class SymbolicShapeRefiner { return Status::OK(); } + // Convert all kUnknownDimFromConst to -1 for shape inference. + ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim( + ic, ctx->input_tensors_as_shapes_to_propagate)); // Notice: UpdateFunction only uses input_tensors_as_shapes, so for function // nodes, we dont' perform the conversion from TensorProtos to Tensors for // constant inputs here. - ic->set_input_tensors_as_shapes(input_tensors_as_shapes); // Properly handle function nodes. if (ctx->op_data && ctx->op_data->is_function_op) { @@ -1218,10 +1316,18 @@ class SymbolicShapeRefiner { } if (c->output_tensors_as_shapes.size() > i && ic->FullyDefined(c->output_tensors_as_shapes[i])) { - continue; + bool no_unknown_dim_from_const = true; + for (int32 j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]); ++j) { + const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j); + if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) { + no_unknown_dim_from_const = false; + break; + } + } + if (no_unknown_dim_from_const) { + continue; + } } - - // Unknown for output[i]. return false; } } @@ -1231,7 +1337,6 @@ class SymbolicShapeRefiner { // Returns true if all the output shapes are known. bool AllOutputShapesKnown(NodeContext* c) { InferenceContext* ic = c->inference_context.get(); - // LOG(INFO) << ic->DebugString(); // Checks if all the output shapes are fully defined. for (int i = 0; i < ic->num_outputs(); i++) { if (!ic->FullyDefined(ic->output(i))) { @@ -1479,6 +1584,13 @@ class SymbolicShapeRefiner { // TODO(bsteiner) We should still propagate the shapes to the ports that // aren't fed in the case of a ShapeN node. + // Note that when propagating tensors_as_shapes, we use + // c->input_tensors_as_shapes_to_progate instead of + // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if + // UnknownDim is from Const tensor, and it is propagated through shape + // inference. Before calling shape functions, we convert it to UnknownDim, + // but instantiate a new UnknownDim to prevent incorrect symbolic shape + // inference through UnknownDim from Const. InferenceContext* ic = c->inference_context.get(); if (!is_fed) { if (IsConstant(node)) { @@ -1531,7 +1643,7 @@ class SymbolicShapeRefiner { bool valid = true; ShapeHandle result; for (int i = 0; i < ic->num_inputs() - 1; ++i) { - ShapeHandle input = ic->input_tensors_as_shapes()[i]; + ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i]; if (!ic->RankKnown(input)) { valid = false; break; @@ -1563,7 +1675,8 @@ class SymbolicShapeRefiner { } else { // Don't have tensor value, but use input_tensors_as_shapes, if // possible. - const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i]; + const ShapeHandle& shape_handle = + c->input_tensors_as_shapes_to_propagate[i]; if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 && ic->ValueKnown(ic->Dim(shape_handle, 0))) { dims.push_back(ic->Dim(shape_handle, 0)); @@ -1578,13 +1691,14 @@ class SymbolicShapeRefiner { } } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) { c->output_tensors_as_shapes.resize(1); - c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0]; + c->output_tensors_as_shapes[0] = + c->input_tensors_as_shapes_to_propagate[0]; if (c->input_tensor_protos[0] != nullptr) { c->output_tensor_protos.resize(1); c->output_tensor_protos[0] = c->input_tensor_protos[0]; } } else if (IsSlice(node)) { - ShapeHandle input = ic->input_tensors_as_shapes()[0]; + ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0]; bool valid = ic->RankKnown(input); const Tensor* slice_offset = ic->input_tensor(1); valid &= slice_offset != nullptr && slice_offset->NumElements() == 1; @@ -1608,7 +1722,7 @@ class SymbolicShapeRefiner { c->output_tensors_as_shapes[0] = result; } } else if (IsStridedSlice(node)) { - ShapeHandle input = ic->input_tensors_as_shapes()[0]; + ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0]; bool valid = ic->RankKnown(input); const Tensor* slice_begin = ic->input_tensor(1); valid &= slice_begin != nullptr && slice_begin->NumElements() == 1; @@ -1773,8 +1887,11 @@ class SymbolicShapeRefiner { int64 value = tensor.dtype() == DT_INT32 ? tensor.flat()(i) : tensor.flat()(i); has_values_smaller_than_minus_1 |= (value < -1); - dims.push_back(value < 0 ? ic->UnknownDim() : ic->MakeDim(value)); + // Mark this as UnknownDim from Const. + dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst) + : ic->MakeDim(value)); } + if (!has_values_smaller_than_minus_1) { *tensors_as_shapes = ic->MakeShape(dims); return true; @@ -2345,6 +2462,9 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, for (int i = 0; i < ic->num_outputs(); ++i) { shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i], &output_properties[i]); + auto converted_output_tensors_as_shapes = + ReplaceUnknownDimFromConstWithUnknownDim( + ic, ctx->output_tensors_as_shapes); if (include_output_tensor_values) { // Export tensor value to output_properties.value. if (IsConstant(node)) { @@ -2355,12 +2475,13 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, ctx->output_tensor_protos[i] != nullptr) { *output_properties[i].mutable_value() = *ctx->output_tensor_protos[i]; - } else if (ctx->output_tensors_as_shapes.size() > i && + } else if (converted_output_tensors_as_shapes.size() > i && IsShapeFullyDefinedIntegerVectorOrScalar( - ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ic, ic->output(i), + converted_output_tensors_as_shapes[i], ctx->output_types[i])) { *output_properties[i].mutable_value() = MakeTensorProtoFromShape( - ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ic, ic->output(i), converted_output_tensors_as_shapes[i], ctx->output_types[i]); } } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index b18d8e2a505..37c41a3dba5 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -211,6 +211,12 @@ class GraphProperties { std::unordered_set incompatible_shape_nodes_; }; +// Helper function for GraphProperties. +bool IsShapeFullyDefinedIntegerVectorOrScalar( + shape_inference::InferenceContext* ic, + const shape_inference::ShapeHandle& shape, + const shape_inference::ShapeHandle& tensor_as_shape, const DataType& dtype); + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 5a18179a33c..135fc521668 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -38,6 +38,10 @@ namespace tensorflow { namespace grappler { namespace { +using shape_inference::InferenceContext; +using shape_inference::ShapeAndType; +using shape_inference::ShapeHandle; + const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata"; REGISTER_OP("TestOpWithNoInferenceFn") @@ -163,12 +167,9 @@ TEST_F(GraphPropertiesTest, StaticProperties) { EXPECT_EQ(1, in_prop.shape().dim(1).size()); const auto out_props = properties.GetOutputProperties(node.name()); EXPECT_EQ(1, out_props.size()); - string in_prop_str; - ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str); - string out_prop_str; - ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0], - &out_prop_str); - EXPECT_EQ(in_prop_str, out_prop_str); + EXPECT_EQ(in_prop.dtype(), out_props[0].dtype()); + EXPECT_EQ(in_prop.shape().DebugString(), + out_props[0].shape().DebugString()); } } } @@ -1728,7 +1729,9 @@ TEST_F(GraphPropertiesTest, FedNodes) { const auto out_props = properties.GetOutputProperties(node.name()); EXPECT_EQ(1, out_props.size()); const OpInfo::TensorProperties& out_prop = out_props[0]; - EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString()); + EXPECT_EQ(in_prop.dtype(), out_prop.dtype()); + EXPECT_EQ(in_prop.shape().DebugString(), + out_prop.shape().DebugString()); } } } @@ -2064,6 +2067,160 @@ TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) { EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); } } + +TEST_F(GraphPropertiesTest, + SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) { + GrapplerItem item; + // This graph creates a shape vector [-1, 10] from Concat(Const, Const) + // used for two reshape ops. One reshape op is segment_ids input to + // UnsortedSegmentSum op, which applies MergePrefix from its shape function. + // segment_ids has a shape [-1, 10] (from reshape), but MergePrefix with + // data input ([10, 10, 10, 10]) makes -1, or unknown dim, 10, with + // SymbolicShapeRefiner. + // This dim value (10), however, should not affect the other reshape op, even + // though it shares the shape input; -1 in the shape input of Reshape op is + // a special case of computed output dim, not unknown dim. + // data and num_segments are inputs to UnsortedSegmenetSum. + + TF_CHECK_OK(NodeDefBuilder("data", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Attr("shape", TensorShape({10, 10, 10, 10})) + .Finalize(item.graph.add_node())); + Tensor num_segments(DT_INT32, TensorShape({})); + // Build semgent_ids input to UnsortedSegmentSum from Const ops, ConcatV2, + // and Reshape ops. tensors_as_shape from Const ops are propagated to ConcatV2 + // output to form shape vector [-1, 10] to Reshape. + test::FillIota(&num_segments, 3); + TF_CHECK_OK(NodeDefBuilder("num_segments", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", num_segments) + .Finalize(item.graph.add_node())); + Tensor minus_one(DT_INT32, TensorShape({1})); + test::FillIota(&minus_one, -1); + TF_CHECK_OK(NodeDefBuilder("minus_one", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", minus_one) + .Finalize(item.graph.add_node())); + Tensor plus_ten(DT_INT32, TensorShape({1})); + test::FillIota(&plus_ten, 10); + TF_CHECK_OK(NodeDefBuilder("plus_ten", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", plus_ten) + .Finalize(item.graph.add_node())); + Tensor axis(DT_INT32, TensorShape({})); + test::FillIota(&axis, -1); + TF_CHECK_OK(NodeDefBuilder("axis", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", axis) + .Finalize(item.graph.add_node())); + std::vector inputs(2); + inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32}; + inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32}; + TF_CHECK_OK(NodeDefBuilder("concat", "ConcatV2") + .Input(inputs) + .Input("axis", 0, DT_INT32) + .Attr("N", 2) + .Attr("T", DT_INT32) + .Attr("Tidx", DT_INT32) + .Finalize(item.graph.add_node())); + TF_CHECK_OK(NodeDefBuilder("segment_ids_", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Finalize(item.graph.add_node())); + TF_CHECK_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape") + .Input("segment_ids_", 0, DT_FLOAT) + .Attr("T", DT_FLOAT) + .Attr("out_type", DT_INT32) + .Finalize(item.graph.add_node())); + TF_CHECK_OK(NodeDefBuilder("segment_ids", "Reshape") + .Input("segment_ids_", 0, DT_FLOAT) + .Input("concat", 0, DT_INT32) + .Attr("T", DT_FLOAT) + .Attr("Tshape", DT_INT32) + .Finalize(item.graph.add_node())); + // Shape function of UnsortedSegmentSum applies MergePrefix to data and + // segment_ids (the latter being prefix). data shape is [10,10,10,10] and + // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference + // assign 10 from data shape to the unknown dim in segment_ids. + TF_CHECK_OK(NodeDefBuilder("y", "UnsortedSegmentSum") + .Input("data", 0, DT_FLOAT) + .Input("segment_ids", 0, DT_INT32) + .Input("num_segments", 0, DT_INT32) + .Attr("T", DT_FLOAT) + .Attr("Tindices", DT_INT32) + .Attr("Tnumsegments", DT_INT32) + .Finalize(item.graph.add_node())); + // Note that y2=Reshape(x1) using the same shape vector as segment_ids, but + // y2 shape shouldn't be affected by symbolic shape inference w/ segment_ids. + TF_CHECK_OK(NodeDefBuilder("x1", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Finalize(item.graph.add_node())); + TF_CHECK_OK(NodeDefBuilder("y1", "Reshape") + .Input("x1", 0, DT_FLOAT) + .Input("concat", 0, DT_INT32) + .Attr("T", DT_FLOAT) + .Attr("Tshape", DT_INT32) + .Finalize(item.graph.add_node())); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(true)); + const auto& y1_output_properties = properties.GetOutputProperties("y1"); + // y1=reshape(x1), but x1's shape in unknown, so y1 should be [-1, 10]. + // The first dimensino should not be 10. + EXPECT_EQ(y1_output_properties.size(), 1); + EXPECT_EQ(y1_output_properties[0].shape().dim_size(), 2); + EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0); + EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10); +} + +TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) { + // Make a dummy InferenceContext. + NodeDef node_def; + OpRegistrationData op_reg_data; + OpDefBuilder b("dummy"); + CHECK(b.Finalize(&op_reg_data).ok()); + std::vector>> + input_handle_shapes_and_types; + InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def, + /*input_shapes=*/{}, + /*input_tensors=*/{}, + /*input_tensors_as_shapes=*/{}, + std::move(input_handle_shapes_and_types)); + + // ShapeHandles for testing. + ShapeHandle fully_defined_vector = ic.MakeShape( + {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)}); + ShapeHandle vector_with_unknown = ic.MakeShape( + {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)}); + // INT64_MAX is used as unknown from Const. See kUnknownFromConst const in + // graph_properties.cc + ShapeHandle vector_with_unknown_from_const = ic.MakeShape( + {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)}); + ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)}); + + // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64. + EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, fully_defined_vector, DT_INT32)); + EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, fully_defined_vector, DT_INT64)); + + // Non-integer data type. + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, fully_defined_vector, DT_FLOAT)); + + // tensor_as_shape including Unknown or UnknownFromConst. + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, vector_with_unknown, DT_INT32)); + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32)); + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, rank_1_vector, ic.UnknownShape(), DT_INT32)); + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32)); + + // shape rank > 1. + EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar( + &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32)); +} } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index ac592003a94..7e7f487fa37 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -69,6 +69,7 @@ class AutoMixedPrecisionLists { "CudnnRNNBackpropV3", "CudnnRNNV2", "CudnnRNNV3", + "Einsum", "GRUBlockCell", "GRUBlockCellGrad", "LSTMBlockCell", diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc index 8924e4c6bea..2b36296a273 100644 --- a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc @@ -251,6 +251,9 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) { CanonicalizeNode(fanout); } } + if (fetch_nodes_known_) { + node->Clear(); + } duplicates.insert(i); stop = false; } diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 569f6433edb..6c829bb353b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -551,9 +551,9 @@ TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) { TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) { for (string data_format : { "NHWC", -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM "NCHW" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM }) { MulConvPushDownTest( /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3} @@ -569,9 +569,9 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) { TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) { for (string data_format : { "NHWC", -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM "NCHW" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM }) { for (auto mul_const_input_shape : {TensorShape{1}, TensorShape{1, 1, 1, 1}}) { @@ -590,9 +590,9 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) { for (string data_format : { "NHWC", -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM "NCHW" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM }) { MulConvPushDownTest( /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3} @@ -608,9 +608,9 @@ TEST_F(ConstantFoldingTest, TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) { for (auto data_format : { "NHWC", -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM "NCHW" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM }) { MulConvPushDownTest( /*input_shape=*/{3, 3, 3, 3}, @@ -635,7 +635,7 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) { } } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) { for (auto mul_const_input_shape : {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) { @@ -649,14 +649,14 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) { /*expect_folded=*/false); } } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) { for (auto data_format : { "NHWC", -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM "NCHW" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM }) { MulConvPushDownTest( /*input_shape=*/{3, 3, 3, 3}, @@ -668,6 +668,9 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) { } } +// This test fails on ROCm platform with two vaue miscompare +// TODO(rocm) : analysze and fix the cause of the failure and re-enable test +#ifndef TENSORFLOW_USE_ROCM TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) { MulConvPushDownTest( /*input_shape=*/{3, 3, 3, 3, 3}, @@ -678,6 +681,7 @@ TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) { /*data_format=*/"NDHWC", /*expect_folded=*/true); } +#endif TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) { MulConvPushDownTest( diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 9213b37ff29..c5c3fcd9665 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -74,6 +74,8 @@ bool IsShapeConsumer(const NodeDef& node) { NodeMap::NodeMap(GraphDef* graph) { CHECK(graph != nullptr); + nodes_.reserve(graph->node_size()); + outputs_.reserve(graph->node_size()); for (int i = 0; i < graph->node_size(); i++) { NodeDef* node = graph->mutable_node(i); const string& node_name = node->name(); diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 0af40436490..c39abfac5fb 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/split_lib.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/monitoring/percentile_sampler.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/errors.h" @@ -34,6 +35,53 @@ limitations under the License. namespace tensorflow { +namespace { + +void RecordPaddingSize(int32 padding_size, const string& model_name) { + static tensorflow::monitoring::PercentileSamplerCell* cell = + tensorflow::monitoring::PercentileSampler<1>::New( + {"/tensorflow/serving/batching/padding_size", "model_name", + "Tracks the padding size distribution on batches by model_name (if " + "available)."}, + /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0}, + /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber) + ->GetCell(model_name); + cell->Add(static_cast(padding_size)); +} + +void RecordInputBatchSize(int32 batch_size, const string& model_name) { + static tensorflow::monitoring::PercentileSamplerCell* cell = + tensorflow::monitoring::PercentileSampler<1>::New( + {"/tensorflow/serving/batching/input_batch_size", "model_name", + "Tracks the batch size distribution on the inputs by model_name (if " + "available)."}, + /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0}, + /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber) + ->GetCell(model_name); + cell->Add(static_cast(batch_size)); +} + +void RecordBatchDelayMs(int64 batch_delay_ms, const string& model_name) { + static monitoring::PercentileSamplerCell* cell = + monitoring::PercentileSampler<1>::New( + {"/tensorflow/serving/batching/batch_delay_ms", "model_name", + "Tracks the batching delay for inputs by model_name (if " + "available)."}, + /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0}, + /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime) + ->GetCell(model_name); + cell->Add(static_cast(batch_delay_ms)); +} + +const string& GetModelName(OpKernelContext* ctx) { + static string* kModelNameUnset = new string("model_name_unset"); + if (!ctx->session_metadata()) return *kModelNameUnset; + if (ctx->session_metadata()->name().empty()) return *kModelNameUnset; + return ctx->session_metadata()->name(); +} + +} // namespace + typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL @@ -246,6 +294,7 @@ class BatchResource : public ResourceBase { const string& batcher_queue_name, AsyncOpKernel::DoneCallback done_callback) { auto batch_components = MakeUnique(); + batch_components->start_time = EnvTime::NowNanos(); batch_components->guid = guid; batch_components->propagated_context = Context(ContextKind::kThread); OpInputList tensors; @@ -264,6 +313,7 @@ class BatchResource : public ResourceBase { } batch_components->inputs.push_back(tensor); } + RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context)); OpInputList captured_tensors; const auto captured_status = context->input_list("captured_tensors", &captured_tensors); @@ -298,6 +348,8 @@ class BatchResource : public ResourceBase { AsyncOpKernel::DoneCallback done_callback; size_t size() const override { return inputs[0].shape().dim_size(0); } + + uint64 start_time; }; using Batcher = serving::SharedBatchScheduler; @@ -344,6 +396,7 @@ class BatchResource : public ResourceBase { const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size()); const int padding_amount = padded_batch_size - batch.size(); + RecordPaddingSize(padding_amount, GetModelName(context)); // All tasks should have the same number of input edges. const int num_inputs = batch.task(0).inputs.size(); @@ -526,6 +579,12 @@ class BatchResource : public ResourceBase { batch->task(batch->num_tasks() - 1).captured_inputs; args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); + uint64 current_time = EnvTime::NowNanos(); + const string& model_name = GetModelName(last_task_context); + for (int i = 0; i < batch->num_tasks(); ++i) { + RecordBatchDelayMs((current_time - batch->task(i).start_time) * 1e-6, + model_name); + } // Releases the cleanup method here, because the callback of the function // library runtime will handle it now. finally.release(); diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index 0aa848e3555..fedea93849c 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool_interface.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -92,6 +93,10 @@ class AdaptiveSharedBatchScheduler // for num_batch_threads allows for large in_flight_batches_limit_, which // will harm latency for some time once load increases again. int64 num_batch_threads = port::MaxParallelism(); + // You can pass a ThreadPoolInterface directly rather than the above two + // parameters. If given, the above two parameers are ignored. Ownership of + // the threadpool is not transferred. + thread::ThreadPoolInterface* thread_pool = nullptr; // Lower bound for in_flight_batches_limit_. As discussed above, can be used // to minimize the damage caused by the random walk under low load. int64 min_in_flight_batches_limit = 1; @@ -356,8 +361,12 @@ AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( rand_double_(0.0, 1.0) { std::random_device device; rand_engine_.seed(device()); - batch_thread_pool_.reset(new thread::ThreadPool( - GetEnv(), options.thread_pool_name, options.num_batch_threads)); + if (options.thread_pool == nullptr) { + batch_thread_pool_.reset(new thread::ThreadPool( + GetEnv(), options.thread_pool_name, options.num_batch_threads)); + } else { + batch_thread_pool_.reset(new thread::ThreadPool(options.thread_pool)); + } } template diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 4f26aed641e..4c1fa5983f2 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -73,7 +73,7 @@ ConstantOp::ConstantOp(OpKernelConstruction* ctx) : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; - MEMDEBUG_CACHE_OP(ctx->def().name().c_str()); + auto op_annotation = ScopedMemoryDebugAnnotation(name_view().data()); OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( *proto, AllocatorAttributes(), &tensor_)); diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 1e26bddd7cb..c8e83b6f672 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -274,10 +274,14 @@ void MergeOp::Compute(OpKernelContext* context) { } else { context->set_output(0, context->input(i)); } - Tensor* value_index = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output(1, TensorShape({}), &value_index)); - value_index->scalar()() = i; + // The value_index output is typically used only in gradient calculations, + // so we can avoid allocating in many inference workloads. + if (context->output_required(1)) { + Tensor* value_index = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), + &value_index)); + value_index->scalar()() = i; + } } } } diff --git a/tensorflow/core/kernels/conv_2d_gpu.h b/tensorflow/core/kernels/conv_2d_gpu.h index 22d7f939686..31abe9dfead 100644 --- a/tensorflow/core/kernels/conv_2d_gpu.h +++ b/tensorflow/core/kernels/conv_2d_gpu.h @@ -236,7 +236,7 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( // One extra line in the inner dimension to avoid share memory bank conflict. // This is to mimic the following, but no constructor of T can be invoked. // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_COMPILER_IS_HIP_CLANG __shared__ __align__( alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)]; typedef T(*SharedMemoryTile)[TileSizeJ + 1]; diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index d479963556f..c92bd4b9582 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -450,14 +450,12 @@ class Conv2DBackpropInputOp : public OpKernel { const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); const Tensor& out_backprop = context->input(2); - OP_REQUIRES( - context, TensorShapeUtils::IsVector(input_sizes.shape()), - errors::InvalidArgument( - "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", - input_sizes.dims())); + TensorShape input_shape; - OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - input_sizes.vec(), &input_shape)); + OP_REQUIRES_OK(context, + Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), + out_backprop.shape(), + data_format_, &input_shape)); Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, @@ -549,14 +547,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel { const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); const Tensor& out_backprop = context->input(2); - OP_REQUIRES( - context, TensorShapeUtils::IsVector(input_sizes.shape()), - errors::InvalidArgument( - "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", - input_sizes.dims())); + TensorShape input_shape; - OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - input_sizes.vec(), &input_shape)); + OP_REQUIRES_OK(context, + Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), + out_backprop.shape(), + data_format_, &input_shape)); ConvBackpropDimensions dims; OP_REQUIRES_OK(context, diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index 81c20ab0c7f..acb052968e1 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -166,4 +166,35 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, dims); } +Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes, + const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, + const TensorFormat& data_format, + TensorShape* input_shape) { + if (!TensorShapeUtils::IsVector(input_sizes.shape())) { + return errors::InvalidArgument( + "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", + input_sizes.dims()); + } + + if (input_sizes.dim_size(0) == 4) { + return TensorShapeUtils::MakeShape(input_sizes.vec(), input_shape); + } + + if (input_sizes.dim_size(0) == 2) { + const int batch_size = GetTensorDim(out_backprop_shape, data_format, 'N'); + const int output_height = input_sizes.vec()(0); + const int output_width = input_sizes.vec()(1); + const int output_depth = filter_shape.dim_size(2); + *input_shape = ShapeFromFormat(data_format, batch_size, output_height, + output_width, output_depth); + return Status::OK(); + } + + return errors::InvalidArgument( + "Conv2DBackpropInput requires input_sizes to " + "contain 4 values or 2 values, but got: ", + input_sizes.dim_size(0)); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.h b/tensorflow/core/kernels/conv_grad_shape_utils.h index 3cc6a3077b3..90a76f4fc3e 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.h +++ b/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -83,6 +83,13 @@ Status ConvBackpropComputeDimensionsV2( const gtl::ArraySlice& dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims); + +// Computes the shape of the in_backprop. +Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes, + const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, + const TensorFormat& data_format, + TensorShape* input_shape); } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 1917d2a0db0..968bb509cb9 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -405,10 +405,11 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime( TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(), /*parent=*/nullptr, /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr, - [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { - *r = new IntraProcessRendezvous(device_mgr); - return Status::OK(); - }); + Rendezvous::Factory{ + [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) { + *r = new IntraProcessRendezvous(device_mgr); + return Status::OK(); + }}); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); if (thread_pool_ == nullptr) { runner_ = [](const std::function& fn) { fn(); }; @@ -548,8 +549,7 @@ Status DatasetOpsTestBase::AddDatasetInput( inputs->size(), " vs. ", input_types.size()); } bool is_ref = IsRefType(input_types[inputs->size()]); - std::unique_ptr input = - absl::make_unique(allocator_, dtype, shape); + auto input = absl::make_unique(allocator_, dtype, shape); if (is_ref) { DataType expected_dtype = RemoveRefType(input_types[inputs->size()]); diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 82ae6c4d643..b1cb66da346 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -201,6 +201,7 @@ tf_kernel_library( tf_kernel_library( name = "lmdb_dataset_op", srcs = ["lmdb_dataset_op.cc"], + hdrs = ["lmdb_dataset_op.h"], deps = [ "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", @@ -210,6 +211,25 @@ tf_kernel_library( ], ) +tf_cc_test( + name = "lmdb_dataset_op_test", + size = "small", + srcs = ["lmdb_dataset_op_test.cc"], + data = ["//tensorflow/core:lmdb_testdata"], + deps = [ + ":lmdb_dataset_op", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", + "//tensorflow/core/platform:env", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "map_and_batch_dataset_op", srcs = ["map_and_batch_dataset_op.cc"], diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc index 1852bc51407..7cfa74e6516 100644 --- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ -12,214 +12,220 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h" + #include +#include "lmdb.h" // NOLINT(build/include) #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/file_system.h" -#include "lmdb.h" // NOLINT(build/include) - namespace tensorflow { namespace data { namespace experimental { -namespace { -class LMDBDatasetOp : public DatasetOpKernel { +/* static */ constexpr const char* const LMDBDatasetOp::kDatasetType; +/* static */ constexpr const char* const LMDBDatasetOp::kFileNames; +/* static */ constexpr const char* const LMDBDatasetOp::kOutputTypes; +/* static */ constexpr const char* const LMDBDatasetOp::kOutputShapes; + +class LMDBDatasetOp::Dataset : public DatasetBase { public: - using DatasetOpKernel::DatasetOpKernel; - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - const Tensor* filenames_tensor; - OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + Dataset(OpKernelContext* ctx, const std::vector& filenames) + : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} - std::vector filenames; - filenames.reserve(filenames_tensor->NumElements()); - for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); - } + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + Iterator::Params{this, strings::StrCat(prefix, "::LMDB")}); + } - *output = new Dataset(ctx, filenames); + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}, {}}); + return *shapes; + } + + string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + + Status CheckExternalState() const override { return Status::OK(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); } private: - class Dataset : public DatasetBase { + class Iterator : public DatasetIterator { public: - Dataset(OpKernelContext* ctx, const std::vector& filenames) - : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} + explicit Iterator(const Params& params) + : DatasetIterator(params) {} - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::LMDB")}); + ~Iterator() override { + // Close any open database connections. + ResetStreamsLocked(); } - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = - new DataTypeVector({DT_STRING, DT_STRING}); - return *dtypes; + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + if (mdb_cursor_) { + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + Tensor& key_tensor = out_tensors->back(); + key_tensor.scalar()() = string( + static_cast(mdb_key_.mv_data), mdb_key_.mv_size); + + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + Tensor& value_tensor = out_tensors->back(); + value_tensor.scalar()() = string( + static_cast(mdb_value_.mv_data), mdb_value_.mv_size); + + int val; + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + ++current_file_index_; + } + *end_of_sequence = false; + return Status::OK(); + } + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + ResetStreamsLocked(); + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); } - const std::vector& output_shapes() const override { - static std::vector* shapes = - new std::vector({{}, {}}); - return *shapes; - } - - string DebugString() const override { return "LMDBDatasetOp::Dataset"; } - - Status CheckExternalState() const override { return Status::OK(); } - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); - return Status::OK(); + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); } private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - do { - if (mdb_cursor_) { - out_tensors->emplace_back(ctx->allocator({}), DT_STRING, - TensorShape({})); - Tensor& key_tensor = out_tensors->back(); - key_tensor.scalar()() = string( - static_cast(mdb_key_.mv_data), mdb_key_.mv_size); - - out_tensors->emplace_back(ctx->allocator({}), DT_STRING, - TensorShape({})); - Tensor& value_tensor = out_tensors->back(); - value_tensor.scalar()() = - string(static_cast(mdb_value_.mv_data), - mdb_value_.mv_size); - - int val; - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - ++current_file_index_; - } - *end_of_sequence = false; - return Status::OK(); - } - if (current_file_index_ == dataset()->filenames_.size()) { - *end_of_sequence = true; - return Status::OK(); - } - - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - } while (true); + Status SetupStreamsLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); } + const string& filename = dataset()->filenames_[current_file_index_]; - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeSourceNode(std::move(args)); + int val = mdb_env_create(&mdb_env_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); + val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); } - - private: - Status SetupStreamsLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (current_file_index_ >= dataset()->filenames_.size()) { - return errors::InvalidArgument( - "current_file_index_:", current_file_index_, - " >= filenames_.size():", dataset()->filenames_.size()); - } - const string& filename = dataset()->filenames_[current_file_index_]; - - int val = mdb_env_create(&mdb_env_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; - - struct stat source_stat; - if (stat(filename.c_str(), &source_stat) == 0 && - (source_stat.st_mode & S_IFREG)) { - flags |= MDB_NOSUBDIR; - } - val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - } - return Status::OK(); + val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); } - void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (mdb_env_ != nullptr) { - if (mdb_cursor_) { - mdb_cursor_close(mdb_cursor_); - mdb_cursor_ = nullptr; - } - mdb_dbi_close(mdb_env_, mdb_dbi_); - mdb_txn_abort(mdb_txn_); - mdb_env_close(mdb_env_); - mdb_txn_ = nullptr; - mdb_dbi_ = 0; - mdb_env_ = nullptr; - } + val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); } - mutex mu_; - size_t current_file_index_ TF_GUARDED_BY(mu_) = 0; - MDB_env* mdb_env_ TF_GUARDED_BY(mu_) = nullptr; - MDB_txn* mdb_txn_ TF_GUARDED_BY(mu_) = nullptr; - MDB_dbi mdb_dbi_ TF_GUARDED_BY(mu_) = 0; - MDB_cursor* mdb_cursor_ TF_GUARDED_BY(mu_) = nullptr; + val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + } + return Status::OK(); + } + void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + mutex mu_; + size_t current_file_index_ TF_GUARDED_BY(mu_) = 0; + MDB_env* mdb_env_ TF_GUARDED_BY(mu_) = nullptr; + MDB_txn* mdb_txn_ TF_GUARDED_BY(mu_) = nullptr; + MDB_dbi mdb_dbi_ TF_GUARDED_BY(mu_) = 0; + MDB_cursor* mdb_cursor_ TF_GUARDED_BY(mu_) = nullptr; - MDB_val mdb_key_ TF_GUARDED_BY(mu_); - MDB_val mdb_value_ TF_GUARDED_BY(mu_); - }; - - const std::vector filenames_; + MDB_val mdb_key_ TF_GUARDED_BY(mu_); + MDB_val mdb_value_ TF_GUARDED_BY(mu_); }; + + const std::vector filenames_; }; +void LMDBDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + *output = new Dataset(ctx, filenames); +} + +namespace { + REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h new file mode 100644 index 00000000000..f58473a7a86 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "LMDB"; + static constexpr const char* const kFileNames = "filenames"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + using DatasetOpKernel::DatasetOpKernel; + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op_test.cc new file mode 100644 index 00000000000..80705229d2c --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op_test.cc @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace experimental { +namespace { + +constexpr char kNodeName[] = "lmdb_dataset"; +constexpr char kIteratorPrefix[] = "Iterator"; +constexpr char kDataFileName[] = "data.mdb"; +constexpr char kDataFileLoc[] = "core/lib/lmdb/testdata"; + +class LMDBDatasetParams : public DatasetParams { + public: + LMDBDatasetParams(std::vector filenames, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + kNodeName), + filenames_(CreateTensor( + TensorShape({static_cast(filenames.size())}), filenames)) {} + + std::vector GetInputTensors() const override { return {filenames_}; } + + Status GetInputNames(std::vector* input_names) const override { + *input_names = {LMDBDatasetOp::kFileNames}; + return Status::OK(); + } + + Status GetAttributes(AttributeVector* attributes) const override { + *attributes = {{LMDBDatasetOp::kOutputTypes, output_dtypes_}, + {LMDBDatasetOp::kOutputShapes, output_shapes_}}; + return Status::OK(); + } + + string dataset_type() const override { return LMDBDatasetOp::kDatasetType; } + + private: + // Names of binary database files to read, boxed up inside a Tensor of + // strings + Tensor filenames_; +}; + +class LMDBDatasetOpTest : public DatasetOpsTestBase {}; + +// Copy our test data file to the current test program's temporary +// directory, and return the full path to the copied file. +// This copying is necessary because LMDB creates lock files adjacent +// to files that it reads. +tstring MaybeCopyDataFile() { + tstring src_loc = + io::JoinPath(testing::TensorFlowSrcRoot(), kDataFileLoc, kDataFileName); + tstring dest_loc = io::JoinPath(testing::TmpDir(), kDataFileName); + + FileSystem* fs; // Pointer to singleton + TF_EXPECT_OK(Env::Default()->GetFileSystemForFile(src_loc, &fs)); + + // FileSystem::FileExists currently returns Status::OK() if the file + // exists and errors::NotFound() if the file doesn't exist. There's no + // indication in the code or docs about whether other error codes may be + // added in the future, so we code defensively here. + Status exists_status = fs->FileExists(dest_loc); + if (exists_status.code() == error::NOT_FOUND) { + TF_EXPECT_OK(fs->CopyFile(src_loc, dest_loc)); + } else { + TF_EXPECT_OK(exists_status); + } + + return dest_loc; +} + +LMDBDatasetParams SingleValidInput() { + return {/*filenames=*/{MaybeCopyDataFile()}, + /*output_dtypes=*/{DT_STRING, DT_STRING}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*node_name=*/kNodeName}; +} + +LMDBDatasetParams TwoValidInputs() { + return {/*filenames*/ {MaybeCopyDataFile(), MaybeCopyDataFile()}, + /*output_dtypes*/ {DT_STRING, DT_STRING}, + /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}, + /*node_name=*/kNodeName}; +} + +LMDBDatasetParams EmptyInput() { + return {/*filenames*/ {}, + /*output_dtypes*/ {DT_STRING, DT_STRING}, + /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}, + /*node_name=*/kNodeName}; +} + +LMDBDatasetParams InvalidPathAtStart() { + return {/*filenames*/ {"This is not a valid filename", MaybeCopyDataFile()}, + /*output_dtypes*/ {DT_STRING, DT_STRING}, + /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}, + /*node_name=*/kNodeName}; +} + +LMDBDatasetParams InvalidPathInMiddle() { + return {/*filenames*/ {MaybeCopyDataFile(), "This is not a valid filename", + MaybeCopyDataFile()}, + /*output_dtypes*/ {DT_STRING, DT_STRING}, + /*output_shapes*/ {PartialTensorShape({}), PartialTensorShape({})}, + /*node_name=*/kNodeName}; +} + +// The tensors we expect to see each time we read through the input data file. +std::vector> GetNextTestCases() { + const std::vector kFileOutput = CreateTensors( + TensorShape({}), + { + // Each call to GetNext() produces two scalar string tensors, but the + // test harness expects to receive a flat vector + {"0"}, {"a"}, // + {"1"}, {"b"}, // + {"2"}, {"c"}, // + {"3"}, {"d"}, // + {"4"}, {"e"}, // + {"5"}, {"f"}, // + {"6"}, {"g"}, // + {"7"}, {"h"}, // + {"8"}, {"i"}, // + {"9"}, {"j"}, // + }); + + // STL vectors don't have a "concatenate two vectors into a new vector" + // operation, so... + std::vector output_twice; + output_twice.insert(output_twice.end(), kFileOutput.cbegin(), + kFileOutput.cend()); + output_twice.insert(output_twice.end(), kFileOutput.cbegin(), + kFileOutput.cend()); + + return { + {/*dataset_params=*/SingleValidInput(), /*expected_outputs=*/kFileOutput}, + {/*dataset_params=*/TwoValidInputs(), /*expected_outputs=*/output_twice}, + {/*dataset_params=*/EmptyInput(), /*expected_outputs=*/{}}}; +} + +ITERATOR_GET_NEXT_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + GetNextTestCases()); + +TEST_F(LMDBDatasetOpTest, InvalidPathAtStart) { + auto dataset_params = InvalidPathAtStart(); + TF_ASSERT_OK(Initialize(dataset_params)); + + // Errors about invalid files are only raised when attempting to read data. + bool end_of_sequence = false; + std::vector out_tensors; + std::vector next; + + Status get_next_status = + iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence); + + EXPECT_EQ(get_next_status.code(), error::INVALID_ARGUMENT); +} + +TEST_F(LMDBDatasetOpTest, InvalidPathInMiddle) { + auto dataset_params = InvalidPathInMiddle(); + TF_ASSERT_OK(Initialize(dataset_params)); + + bool end_of_sequence = false; + std::vector out_tensors; + std::vector next; + + // First 10 rows should be ok + for (int i = 0; i < 10; ++i) { + TF_ASSERT_OK( + iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence)); + EXPECT_FALSE(end_of_sequence); + } + + // Next read operation should raise an error + Status get_next_status = + iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence); + EXPECT_EQ(get_next_status.code(), error::INVALID_ARGUMENT); +} + +std::vector> +DatasetNodeNameTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_node_name=*/kNodeName}}; +} + +DATASET_NODE_NAME_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + DatasetNodeNameTestCases()); + +std::vector> +DatasetTypeStringTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_dataset_type_string=*/name_utils::OpName( + LMDBDatasetOp::kDatasetType)}}; +} + +DATASET_TYPE_STRING_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + DatasetTypeStringTestCases()); + +std::vector> +DatasetOutputDtypesTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_output_dtypes=*/{DT_STRING, DT_STRING}}}; +} + +DATASET_OUTPUT_DTYPES_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + DatasetOutputDtypesTestCases()); + +std::vector> +DatasetOutputShapesTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_output_shapes=*/{PartialTensorShape({}), + PartialTensorShape({})}}}; +} + +DATASET_OUTPUT_SHAPES_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + DatasetOutputShapesTestCases()); + +std::vector> CardinalityTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_cardinality=*/kUnknownCardinality}}; +} + +DATASET_CARDINALITY_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + CardinalityTestCases()); + +std::vector> +IteratorOutputDtypesTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_output_dtypes=*/{DT_STRING, DT_STRING}}}; +} + +ITERATOR_OUTPUT_DTYPES_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + IteratorOutputDtypesTestCases()); + +std::vector> +IteratorOutputShapesTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_output_shapes=*/{PartialTensorShape({}), + PartialTensorShape({})}}}; +} + +ITERATOR_OUTPUT_SHAPES_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + IteratorOutputShapesTestCases()); + +std::vector> +IteratorOutputPrefixTestCases() { + return {{/*dataset_params=*/SingleValidInput(), + /*expected_iterator_prefix=*/name_utils::IteratorPrefix( + LMDBDatasetOp::kDatasetType, kIteratorPrefix)}}; +} + +ITERATOR_PREFIX_TEST_P(LMDBDatasetOpTest, LMDBDatasetParams, + IteratorOutputPrefixTestCases()); + +// No test of save and restore; save/restore functionality is not implemented +// for this dataset. + +} // namespace +} // namespace experimental +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 26f68431203..e0312b5fe08 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/stringprintf.h" namespace tensorflow { @@ -761,8 +762,16 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, int64 count = 1; if (op_version_ == 2) { SeedGenerator* seed_generator = nullptr; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator)); + Status s = LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator); + if (errors::IsNotFound(s)) { + LOG(WARNING) << "Failed to find seed generator resource. Falling back to " + "using a non-deterministically-seeded seed generator."; + *output = + new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, Seeds(0, 0), + count, reshuffle_each_iteration_); + return; + } + OP_REQUIRES_OK(ctx, s); // Create a fresh handle for the resource because the input handle can // become invalid after this op executes. diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index d6c157b18e5..70ef069837e 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -50,7 +50,7 @@ void ArgOp::Compute(OpKernelContext* ctx) { errors::InvalidArgument("Type mismatch: actual ", DataTypeString(val.dtype()), " vs. expect ", DataTypeString(dtype_))); - ctx->set_output(0, val); + ctx->set_output(0, std::move(val)); } RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -279,7 +279,7 @@ class SymbolicGradientOp : public AsyncOpKernel { " tensor(s), but get ", rets->size(), " tensor(s) instead.")); } else { for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); + ctx->set_output(i, std::move((*rets)[i])); } } delete rets; @@ -413,7 +413,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { ctx->SetStatus(status); } else { for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); + ctx->set_output(i, std::move((*rets)[i])); } } delete rets; diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index eb6b5cdce3a..c8cebd0ff4d 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -123,22 +123,10 @@ class IfOp : public AsyncOpKernel { ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - auto lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle then_handle; FHandle else_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done); - + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle), + done); bool cond; OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); @@ -148,6 +136,10 @@ class IfOp : public AsyncOpKernel { NameAttrList then_func_; NameAttrList else_func_; + mutex mu_; + std::unordered_map> + handles_ GUARDED_BY(mu_); + class State { public: State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, @@ -203,6 +195,42 @@ class IfOp : public AsyncOpKernel { TensorVec args_; TensorVec rets_; }; + + Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, + FHandle* else_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. + auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *then_handle = kInvalidHandle; + *else_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, then_func_, then_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, else_func_, else_handle)); + handles_[lib] = {*then_handle, *else_handle}; + } + } + return Status::OK(); + } }; class CaseOp : public AsyncOpKernel { @@ -332,18 +360,10 @@ class WhileOp : public AsyncOpKernel { auto lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle cond_handle; FHandle body_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle), + done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -351,6 +371,10 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; + mutex mu_; + std::unordered_map> + handles_ GUARDED_BY(mu_); + class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, @@ -486,6 +510,42 @@ class WhileOp : public AsyncOpKernel { delete this; } }; + + Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, + FHandle* body_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. + auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *cond_handle = kInvalidHandle; + *body_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, cond_func_, cond_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, body_func_, body_handle)); + handles_[lib] = {*cond_handle, *body_handle}; + } + } + return Status::OK(); + } }; // TODO(drpng): remove these. REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp); diff --git a/tensorflow/core/kernels/in_topk_op_gpu.cu.cc b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc index 44c96f67b26..f701071cb8e 100644 --- a/tensorflow/core/kernels/in_topk_op_gpu.cu.cc +++ b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc @@ -100,6 +100,12 @@ struct InTopKFunctor { errors::InvalidArgument( "Number of targets * number of classes must be less than INT_MAX")); + if (num_targets == 0 || num_classes == 0) { + // Result is empty, so shortcut the rest of the function to avoid + // launching kernels with empty input. + return; + } + // Temporary storage for a mask computed by `ComputePredictionMaskKernel`. Tensor predictions_mask; OP_REQUIRES_OK( diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc index 614e184b0b2..87f70d3a3b3 100644 --- a/tensorflow/core/kernels/ops_testutil.cc +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/node_properties.h" #ifdef GOOGLE_CUDA #define EIGEN_USE_GPU #include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h" @@ -137,11 +138,16 @@ Status OpsTestBase::InitOp() { } Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) { - Status status; - kernel_ = CreateOpKernel(device_type_, device_, allocator(), node_def_, - graph_def_version, &status); - if (kernel_ != nullptr) input_types_ = kernel_->input_types(); - return status; + std::shared_ptr props; + TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( + node_def_, OpRegistry::Global(), &props)); + OpKernel* kernel; + TF_RETURN_IF_ERROR(CreateOpKernel( + device_type_, device_, allocator(), /*flib=*/nullptr, + device_->resource_manager(), props, graph_def_version, &kernel)); + kernel_.reset(kernel); + input_types_ = kernel_->input_types(); + return Status::OK(); } Status OpsTestBase::RunOpKernel() { diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc index 9e5a11c4aeb..364c2c07bd8 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc @@ -224,10 +224,8 @@ REGISTER_CPU(complex128) REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc index fdcff6876c3..459bb219343 100644 --- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc @@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { REGISTER_GPU(GPU, float) REGISTER_GPU(GPU, double) -#if GOOGLE_CUDA REGISTER_GPU(GPU, complex64) REGISTER_GPU(GPU, complex128) -#endif namespace functor { diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index a57d97b7a73..1a9186b7e4b 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t)); const GPUDevice& d = ctx->eigen_device(); - - if (b_outer_dim == 1) { + bool use_matrix_vector_multiply = (b_outer_dim == 1); +#if TENSORFLOW_USE_ROCM + // ROCm hipsparse does not implement csrmv with transposed input a + use_matrix_vector_multiply = + use_matrix_vector_multiply && !this->transpose_a_; +#endif + if (use_matrix_vector_multiply) { // Call matrix-vector multiply if b is a vector. TTypes::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim, 2); diff --git a/tensorflow/core/kernels/sparse/mul_op.cc b/tensorflow/core/kernels/sparse/mul_op.cc index f6cf369626c..33c3756ce58 100644 --- a/tensorflow/core/kernels/sparse/mul_op.cc +++ b/tensorflow/core/kernels/sparse/mul_op.cc @@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel { REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #undef REGISTER_GPU diff --git a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc index 9cbe88bde6c..59540f63846 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc @@ -120,10 +120,8 @@ REGISTER(CPU, complex128) REGISTER(GPU, float) REGISTER(GPU, double) -#if GOOGLE_CUDA REGISTER(GPU, complex64) REGISTER(GPU, complex128) -#endif #undef REGISTER @@ -141,10 +139,8 @@ namespace functor { DECLARE_GPU_SPEC(int32); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#if GOOGLE_CUDA DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); -#endif #undef DECLARE_GPU_SPEC } // namespace functor diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc index 893909ef2fa..e1a4b4194d2 100644 --- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc @@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix; REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #undef REGISTER_GPU diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc index 4123b4b8225..ce02aa17225 100644 --- a/tensorflow/core/kernels/unpack_op.cc +++ b/tensorflow/core/kernels/unpack_op.cc @@ -107,6 +107,8 @@ class UnpackOp : public OpKernel { input.shaped({before_dim, axis_dim * after_dim}); for (int i = 0; i < num; ++i) { + if (!context->output_required(i)) continue; + Tensor* output; OP_REQUIRES_OK(context, context->allocate_output(i, output_shape, &output)); diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 374be1ce4ec..42b262b70eb 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -50,15 +50,11 @@ class LegacyVar : public ResourceBase { VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); dtype_ = RemoveRefType(context->output_type(0)); + OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(), + true /* use name() */)); } void VariableOp::Compute(OpKernelContext* ctx) { - mutex_lock l(init_mu_); - if (!initialized_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), - true /* use name() */)); - initialized_ = true; - } auto creator = [this](LegacyVar** var) { *var = new LegacyVar(dtype_); (*var)->tensor()->set_shape(shape_); diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 8a7578770fa..51252221f06 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -36,10 +36,7 @@ class VariableOp : public OpKernel { private: DataType dtype_; TensorShape shape_; - - mutex init_mu_; - ContainerInfo cinfo_ TF_GUARDED_BY(init_mu_); - bool initialized_ TF_GUARDED_BY(init_mu_){false}; + ContainerInfo cinfo_; TF_DISALLOW_COPY_AND_ASSIGN(VariableOp); }; diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 89850ed7ed1..e3bbff69cb5 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/byte_order.h" -#ifdef __CUDACC__ +#if defined(__CUDACC__) || (defined(__HIPCC__) && defined(__HIP__)) // All functions callable from CUDA code must be qualified with __device__ #define B16_DEVICE_FUNC __host__ __device__ diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e2fca5189b8..2e55bc6cd95 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -357,13 +357,7 @@ REGISTER_OP("Conv2DBackpropInput") .Attr(GetExplicitPaddingsAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); - TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); - c->set_output(0, s); - return Status::OK(); - }); + .SetShapeFn(shape_inference::Conv2DBackpropInputShape); // TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a // more general string attribute ('kernel_impl'?) that can be used to diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index b53f7624d96..469a9015a17 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -320,6 +320,25 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) { "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]"); } +TEST(NNOpsTest, Conv2DBackpropInput_ShapeFn) { + ShapeInferenceTestOp op("Conv2DBackpropInput"); + + // Test rank error. + INFER_ERROR("input_sizes to contain 4 values or 2 values", op, + "[3];[?,?,?,?];[?,?,?,?]"); + INFER_ERROR("Shape must be rank 4 but is rank 3", op, + "[4];[?,?,?,?];[?,?,?]"); + + // When input_sizes is a 4D shape and the convolution is grouped, the channel + // size of the input grad doesn't always equal the input channel size of the + // filter. So, when input_sizes is a 4D shape, the channel size of the input + // grad is determined by the content of input_sizes. + INFER_OK(op, "[4];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,?]"); + // When input_sizes is a 2D shape, the channel size of the input grad always + // matches the filter shape. + INFER_OK(op, "[2];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,d1_2]"); +} + TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) { ShapeInferenceTestOp op("Conv3DBackpropInput"); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index bc23cf14b03..dc58c2c5513 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -308,7 +308,8 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { : filename_(filename), read_fn_(std::move(read_fn)), buffer_size_(buffer_size), - buffer_start_(0) {} + buffer_start_(0), + buffer_end_is_past_eof_(false) {} Status Name(StringPiece* result) const override { *result = filename_; @@ -332,9 +333,9 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { memcpy(scratch, buffer_.data() + (offset - buffer_start_), copy_size); *result = StringPiece(scratch, copy_size); } - if (copy_size < n) { - // Try reading from the file regardless of previous read status. - // The file might have grown since the last read. + bool consumed_buffer_to_eof = + offset + copy_size >= buffer_end && buffer_end_is_past_eof_; + if (copy_size < n && !consumed_buffer_to_eof) { Status status = FillBuffer(offset + copy_size); if (!status.ok() && status.code() != errors::Code::OUT_OF_RANGE) { // Empty the buffer to avoid caching bad reads. @@ -347,9 +348,11 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { *result = StringPiece(scratch, copy_size); } if (copy_size < n) { + // Forget the end-of-file flag to allow for clients that poll on the + // same file. + buffer_end_is_past_eof_ = false; return errors::OutOfRange("EOF reached. Requested to read ", n, - " bytes from ", offset, " but only got ", - copy_size, " bytes."); + " bytes from ", offset, "."); } } return Status::OK(); @@ -363,6 +366,7 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { StringPiece str_piece; Status status = read_fn_(filename_, buffer_start_, buffer_size_, &str_piece, &(buffer_[0])); + buffer_end_is_past_eof_ = status.code() == errors::Code::OUT_OF_RANGE; buffer_.resize(str_piece.size()); return status; } @@ -383,6 +387,8 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { // Offset of buffer from start of the file. mutable uint64 buffer_start_ TF_GUARDED_BY(buffer_mutex_); + mutable bool buffer_end_is_past_eof_ TF_GUARDED_BY(buffer_mutex_); + mutable string buffer_ TF_GUARDED_BY(buffer_mutex_); }; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 21cee5d5ebd..802f18a31ae 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -247,6 +247,93 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { EXPECT_EQ("", result); } +TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { + // In this test, there is only one backend request since we cache the file + // size. + std::vector requests({new FakeHttpRequest( + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 0-9\n" + "Timeouts: 5 1 20\n", + "012345678")}); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr(new FakeZoneProvider), 10 /* block size */, + 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, + 0 /* stat cache max entries */, 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); + + std::unique_ptr file; + TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); + + StringPiece filename; + TF_EXPECT_OK(file->Name(&filename)); + EXPECT_EQ(filename, "gs://bucket/random_access.txt"); + + char scratch[5]; + StringPiece result; + + // Read the first chunk. Even though the backend response is out-of-range, + // we should get a OK status since we're just reading the first 5 bytes. + TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); + EXPECT_EQ("01234", result); + + TF_EXPECT_OK(file->Read(4, sizeof(scratch), &result, scratch)); + EXPECT_EQ("45678", result); + + // Return the cached error once the user starts reading out of range. + EXPECT_EQ(errors::Code::OUT_OF_RANGE, + file->Read(5, sizeof(scratch), &result, scratch).code()); + EXPECT_EQ("5678", result); +} + +TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { + // In this test, the second read is seeking backwards, so it should trigger + // a backend request. + std::vector requests( + {new FakeHttpRequest( + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 1-10\n" + "Timeouts: 5 1 20\n", + "12345678"), + new FakeHttpRequest( + "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 0-9\n" + "Timeouts: 5 1 20\n", + "012345678")}); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr(new FakeZoneProvider), 10 /* block size */, + 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, + 0 /* stat cache max entries */, 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); + + std::unique_ptr file; + TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); + + StringPiece filename; + TF_EXPECT_OK(file->Name(&filename)); + EXPECT_EQ(filename, "gs://bucket/random_access.txt"); + + char scratch[5]; + StringPiece result; + + TF_EXPECT_OK(file->Read(1, sizeof(scratch), &result, scratch)); + EXPECT_EQ("12345", result); + TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); + EXPECT_EQ("01234", result); +} + TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { std::vector requests( {new FakeHttpRequest( @@ -282,7 +369,9 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { char scratch[10]; StringPiece result; - // Read the first chunk. + // Read the first chunk. Since the first read is out-of-range, + // we don't cache the out-of-range flag and each subsequent read triggers a + // backend call. EXPECT_EQ(errors::Code::OUT_OF_RANGE, file->Read(0, sizeof(scratch), &result, scratch).code()); EXPECT_EQ("012345678", result); diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 7d8c6aefc33..9e96ceedbdc 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -16,9 +16,6 @@ limitations under the License. #include "tensorflow/core/platform/file_system.h" #include -#if defined(IS_MOBILE_PLATFORM) -#include -#endif #include #include @@ -26,12 +23,15 @@ limitations under the License. #include #include +#if defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) +#include +#else +#include "tensorflow/core/platform/regexp.h" +#endif // defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) + #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/platform.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/platform/regexp.h" -#endif #include "tensorflow/core/platform/scanner.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" @@ -39,16 +39,20 @@ limitations under the License. namespace tensorflow { bool FileSystem::Match(const string& filename, const string& pattern) { -#if defined(IS_MOBILE_PLATFORM) +#if defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) // We avoid relying on RE2 on mobile platforms, because it incurs a // significant binary size increase. + // For POSIX platforms, there is no need to depend on RE2 if `fnmatch` can be + // used safely. return fnmatch(pattern.c_str(), filename.c_str(), FNM_PATHNAME) == 0; #else string regexp(pattern); - RE2::GlobalReplace(®exp, "\\*", "[^/]*"); - RE2::GlobalReplace(®exp, "\\?", "."); + regexp = str_util::StringReplace(regexp, "*", "[^/]*", true); + regexp = str_util::StringReplace(regexp, "?", ".", true); + regexp = str_util::StringReplace(regexp, "(", "\\(", true); + regexp = str_util::StringReplace(regexp, ")", "\\)", true); return RE2::FullMatch(filename, regexp); -#endif +#endif // defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) } string FileSystem::TranslateName(const string& name) const { diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 610e233ef79..1e1062c88c0 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -306,9 +306,24 @@ class HDFSWritableFile : public WritableFile { } Status Append(StringPiece data) override { - if (libhdfs()->hdfsWrite(fs_, file_, data.data(), - static_cast(data.size())) == -1) { - return IOError(filename_, errno); + size_t cur_pos = 0, write_len = 0; + bool retry = false; + // max() - 2 can avoid OutOfMemoryError in JVM . + static const size_t max_len_once = + static_cast(std::numeric_limits::max() - 2); + while (cur_pos < data.size()) { + write_len = std::min(data.size() - cur_pos, max_len_once); + tSize w = libhdfs()->hdfsWrite(fs_, file_, data.data() + cur_pos, + static_cast(write_len)); + if (w == -1) { + if (!retry && (errno == EINTR || errno == EAGAIN)) { + retry = true; + } else { + return IOError(filename_, errno); + } + } else { + cur_pos += w; + } } return Status::OK(); } diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc index 71cf0542d3c..ae5d09c806b 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" +#include "third_party/hadoop/hdfs.h" namespace tensorflow { namespace { @@ -273,6 +274,23 @@ TEST_F(HadoopFileSystemTest, HarRootPath) { EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", nn); EXPECT_EQ("/", path); } + +TEST_F(HadoopFileSystemTest, WriteBigFile) { + const string fname = TmpDir("BigFile"); + const size_t file_len = + static_cast(std::numeric_limits::max()) + 1024; + // Fake a test string . + char* p = new char[file_len]; + for (size_t i = 0; i < file_len; ++i) { + *(p + i) = (i % 128); + } + string file_write_content(p, file_len); + TF_ASSERT_OK(WriteString(fname, file_write_content)); + string file_read_content; + TF_EXPECT_OK(ReadAll(fname, &file_read_content)); + EXPECT_EQ(file_write_content, file_read_content); + delete p; +} // NewAppendableFile() is not testable. Local filesystem maps to // ChecksumFileSystem in Hadoop, where appending is an unsupported operation. diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 88020753072..bfb9a893765 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -21,13 +21,6 @@ package_group( ], ) -tf_proto_library( - name = "op_profile_proto", - srcs = ["op_profile.proto"], - cc_api_version = 2, - visibility = [":internal"], -) - tf_proto_library( name = "profiler_service_monitor_result_proto", srcs = ["profiler_service_monitor_result.proto"], diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 027d03ba152..bc8b937174d 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -37,6 +37,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", + "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:time_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", @@ -158,7 +159,7 @@ cc_library( hdrs = ["trace_events_to_json.h"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@jsoncpp_git//:jsoncpp", @@ -171,9 +172,9 @@ tf_cc_test( deps = [ ":trace_events_to_json", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", "@jsoncpp_git//:jsoncpp", ], ) @@ -304,7 +305,7 @@ cc_library( hdrs = ["xplane_to_trace_events.h"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:tf_xplane_visitor", "//tensorflow/core/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.cc b/tensorflow/core/profiler/convert/op_metrics_to_record.cc index 480040778c5..b51c679776b 100644 --- a/tensorflow/core/profiler/convert/op_metrics_to_record.cc +++ b/tensorflow/core/profiler/convert/op_metrics_to_record.cc @@ -28,7 +28,7 @@ std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, std::vector result; result.reserve(metrics_db.metrics_db_size()); for (const OpMetrics& metrics : metrics_db.metrics_db()) { - if (metrics.occurrences() > 0) result.push_back(&metrics); + result.push_back(&metrics); } auto comp = [](const OpMetrics* a, const OpMetrics* b) { diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc index 4ce14f54d47..23561169c4e 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc @@ -57,7 +57,7 @@ TfStatsTable GenerateTfStatsTable(const OpMetricsDb& host_tf_metrics_db, } double total_device_time_us = PicosToMicros(total_device_time_ps); for (const OpMetrics* metrics : SortedOpMetricsDb(device_tf_metrics_db)) { - if (exclude_idle && metrics->category() == "IDLE") continue; + if (exclude_idle && IsIdleOp(*metrics)) continue; TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); *record = ConvertOpMetricsToTfStatsRecord( /*on_device=*/true, *metrics, ridge_point); @@ -73,7 +73,7 @@ TfStatsTable GenerateTfStatsTable(const OpMetricsDb& host_tf_metrics_db, double total_host_time_us = PicosToMicros(total_host_time_ps); for (const OpMetrics* metrics : tensorflow::profiler::SortedOpMetricsDb(host_tf_metrics_db)) { - if (exclude_idle && metrics->category() == "IDLE") continue; + if (exclude_idle && IsIdleOp(*metrics)) continue; TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); *record = ConvertOpMetricsToTfStatsRecord( /*on_device=*/false, *metrics, ridge_point); diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.cc b/tensorflow/core/profiler/convert/trace_events_to_json.cc index e545bc3384f..9c8176c10ad 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json.cc +++ b/tensorflow/core/profiler/convert/trace_events_to_json.cc @@ -18,7 +18,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "include/json/json.h" -#include "tensorflow/core/protobuf/trace_events.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_events_to_json.h index a71224cf5c4..16747fec737 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_events_to_json.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_EVENTS_TO_JSON_H_ #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/trace_events.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc index da8d57f6f35..dc985f2f76f 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc +++ b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "include/json/json.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/trace_events.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index 3c8d5525370..3ddc5227038 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/time_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" @@ -85,7 +86,7 @@ TEST(ConvertXPlaneToOpMetricsDb, HostOpMetricsDb) { EXPECT_EQ(NanosToPicos(kTfOp1DurationNs) * 2, op_1.time_ps()); const OpMetrics& idle = op_metrics.metrics_db().at(1); - EXPECT_EQ("IDLE", idle.name()); + EXPECT_EQ(kIdle, idle.name()); // Idle time is the gap between Op2 start and the end of Op1, which is 2000ns. EXPECT_EQ(NanosToPicos(2000), idle.time_ps()); @@ -149,7 +150,7 @@ TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) { EXPECT_EQ(NanosToPicos(kTfOp2DurationNs), op_2.time_ps()); const OpMetrics& idle = op_metrics.metrics_db().at(2); - EXPECT_EQ("IDLE", idle.name()); + EXPECT_EQ(kIdle, idle.name()); // GPU is always busy in this example. EXPECT_EQ(NanosToPicos(0), idle.time_ps()); } diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc index 728a75b250d..a40af395558 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc @@ -77,10 +77,10 @@ void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) { event->set_device_id(device_id); event->set_resource_id(resource_id); if (xevent.HasDisplayName()) { - event->set_name(string(xevent.DisplayName())); - args["long_name"] = string(xevent.Name()); + event->set_name(std::string(xevent.DisplayName())); + args["long_name"] = std::string(xevent.Name()); } else { - event->set_name(string(xevent.Name())); + event->set_name(std::string(xevent.Name())); } event->set_timestamp_ps(xevent.TimestampPs()); event->set_duration_ps(xevent.DurationPs()); @@ -88,7 +88,7 @@ void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) { xevent.ForEachStat([&](const XStatVisitor& stat) { if (stat.ValueCase() == XStat::VALUE_NOT_SET) return; if (IsInternalStat(stat.Type())) return; - args[string(stat.Name())] = stat.ToString(); + args[std::string(stat.Name())] = stat.ToString(); }); }); }); @@ -100,5 +100,12 @@ void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) { MaybeDropEventsForTraceViewer(trace, kMaxEvents); } +void ConvertXSpaceToTraceEventsString(const XSpace& xspace, + std::string* content) { + Trace trace; + ConvertXSpaceToTraceEvents(xspace, &trace); + trace.SerializeToString(content); +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.h b/tensorflow/core/profiler/convert/xplane_to_trace_events.h index b8e5f0085f4..5c6fbead805 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_events.h +++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.h @@ -18,14 +18,17 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/protobuf/trace_events.pb.h" namespace tensorflow { namespace profiler { void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace); +void ConvertXSpaceToTraceEventsString(const XSpace& xspace, + std::string* content); + // Not Public API, Testing only. void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit); diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index 6fc78e46862..bfe855ef417 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -53,7 +53,10 @@ tf_cc_test_gpu( srcs = ["device_tracer_test.cc"], args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], + tags = tf_cuda_tests_tags() + [ + "nomac", + "gpu_cupti", + ], deps = [ ":device_tracer", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc index 3ce6678de01..e001d831b65 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc @@ -1348,7 +1348,7 @@ absl::string_view AnnotationMap::LookUp(uint32 device_id, } bool CuptiTracer::IsAvailable() const { - return !activity_tracing_enabled_ && !api_tracing_enabled_; + return NumGpus() && !activity_tracing_enabled_ && !api_tracing_enabled_; } int CuptiTracer::NumGpus() { diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h index 9e7819de03a..74aa56514b7 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/internal/profiler_interface.h @@ -36,8 +36,22 @@ struct ProfilerOptions { // DeviceType::kTpu: only CPU/TPU will be profiled. DeviceType device_type = DeviceType::kUnspecified; - // Inexpensive ops are not traced by default. - int host_tracer_level = 2; + // Levels of host tracing: + // - Level 0 is used to disable host traces. + // - Level 1 enables tracing of only user instrumented (or default) TraceMe. + // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high + // level program execution details (expensive TF ops, XLA ops, etc). + // This is the default. + // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose + // (low-level) program execution details (cheap TF ops, etc). + uint32 host_tracer_level = 2; + + // Levels of device tracing: + // - Level 0 is used to disable device traces. + // - Level 1 is used to enable device traces. + // - More levels might be defined for specific device for controlling the + // verbosity of the trace. + uint32 device_tracer_level = 1; // Whether to enable python function calls tracer. bool enable_python_tracer = false; diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 77de2e03241..7ccbb81a281 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -47,7 +47,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform", "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", @@ -55,7 +54,6 @@ cc_library( ] + if_not_android([ ":profiler_utils", "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/convert:xplane_to_trace_events", "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:group_events", "//tensorflow/core/profiler/utils:xplane_utils", diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index 982a0f93355..89d58ef0f39 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -20,14 +20,10 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/trace_events.pb.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/ptr_util.h" #if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h" #include "tensorflow/core/profiler/internal/profiler_factory.h" #include "tensorflow/core/profiler/lib/profiler_utils.h" #include "tensorflow/core/profiler/utils/derived_timeline.h" @@ -126,16 +122,6 @@ Status ProfilerSession::CollectData(RunMetadata* run_metadata) { return Status::OK(); } -Status ProfilerSession::SerializeToString(string* content) { - profiler::Trace trace; -#if !defined(IS_MOBILE_PLATFORM) - profiler::XSpace xspace; - TF_RETURN_IF_ERROR(CollectData(&xspace)); - profiler::ConvertXSpaceToTraceEvents(xspace, &trace); -#endif - trace.SerializeToString(content); - return Status::OK(); -} ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options) #if !defined(IS_MOBILE_PLATFORM) diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h index 78504312f54..ba977d72567 100644 --- a/tensorflow/core/profiler/lib/profiler_session.h +++ b/tensorflow/core/profiler/lib/profiler_session.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { @@ -52,9 +51,6 @@ class ProfilerSession { tensorflow::Status CollectData(RunMetadata* run_metadata) TF_LOCKS_EXCLUDED(mutex_); - tensorflow::Status SerializeToString(string* content) - TF_LOCKS_EXCLUDED(mutex_); - private: // Constructs an instance of the class and starts profiling explicit ProfilerSession(const profiler::ProfilerOptions& options); diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto index ab7ae49df3c..007b68e9482 100644 --- a/tensorflow/core/profiler/profiler_service.proto +++ b/tensorflow/core/profiler/profiler_service.proto @@ -14,11 +14,37 @@ service ProfilerService { } message ProfileOptions { + // Some default value of option are not proto3 default value. Use this version + // to determine if we should use default option value instead of proto3 + // default value. + uint32 version = 5; + // We don't collect the dataset ops by default for better trace-viewer // scalability. The caller can mannually set this field to include the ops. bool include_dataset_ops = 1; - // next-field: 2 + // Levels of host tracing: (version >= 1) + // - Level 0 is used to disable host traces. + // - Level 1 enables tracing of only user instrumented (or default) TraceMe. + // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high + // level program execution details (expensive TF ops, XLA ops, etc). + // This is the default. + // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose + // (low-level) program execution details (cheap TF ops, etc). + uint32 host_tracer_level = 2; + + // Levels of device tracing: (version >= 1) + // - Level 0 is used to disable device traces. + // - Level 1 is used to enable device traces. + // - More levels might be defined for specific device for controlling the + // verbosity of the trace. + uint32 device_tracer_level = 3; + + // Whether enable python function calls tracing. Runtime overhead ensues if + // enabled. Default off. (version >= 1) + uint32 python_tracer_level = 4; + + // next-field: 6 } message ToolRequestOptions { diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 64e9ecbc0b2..ce5bc9bd120 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -62,6 +62,13 @@ tf_proto_library( ], ) +tf_proto_library( + name = "op_profile_proto", + srcs = ["op_profile.proto"], + cc_api_version = 2, + visibility = [":friends"], +) + tf_proto_library( name = "op_stats_proto", srcs = ["op_stats.proto"], @@ -92,6 +99,13 @@ tf_proto_library( visibility = [":friends"], ) +tf_proto_library( + name = "trace_events_proto", + srcs = ["trace_events.proto"], + cc_api_version = 2, + visibility = [":friends"], +) + tf_proto_library( name = "hardware_types_proto", srcs = ["hardware_types.proto"], diff --git a/tensorflow/core/profiler/op_profile.proto b/tensorflow/core/profiler/protobuf/op_profile.proto similarity index 100% rename from tensorflow/core/profiler/op_profile.proto rename to tensorflow/core/profiler/protobuf/op_profile.proto diff --git a/tensorflow/core/protobuf/trace_events.proto b/tensorflow/core/profiler/protobuf/trace_events.proto similarity index 100% rename from tensorflow/core/protobuf/trace_events.proto rename to tensorflow/core/profiler/protobuf/trace_events.proto diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index a36e517c4c2..370aa00a602 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -26,8 +26,8 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:profiler_service_proto_cc", + "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], diff --git a/tensorflow/core/profiler/rpc/client/save_profile.cc b/tensorflow/core/profiler/rpc/client/save_profile.cc index dad2918f01a..ab2e494871c 100644 --- a/tensorflow/core/profiler/rpc/client/save_profile.cc +++ b/tensorflow/core/profiler/rpc/client/save_profile.cc @@ -28,10 +28,10 @@ limitations under the License. #include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" // Windows.h #defines ERROR, but it is also used in // tensorflow/core/util/event.proto #undef ERROR -#include "tensorflow/core/protobuf/trace_events.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index 01cd35ab2aa..407cd0ae0a6 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/util/ptr_util.h" @@ -51,7 +52,8 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { ::grpc::Status Profile(::grpc::ServerContext* ctx, const ProfileRequest* req, ProfileResponse* response) override { VLOG(1) << "Received a profile request: " << req->DebugString(); - std::unique_ptr profiler = ProfilerSession::Create(); + std::unique_ptr profiler = + ProfilerSession::Create(GetOptions(req->opts())); Status status = profiler->Status(); if (!status.ok()) { return ::grpc::Status(::grpc::StatusCode::INTERNAL, @@ -74,6 +76,19 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { return ::grpc::Status::OK; } + + private: + profiler::ProfilerOptions GetOptions(const tensorflow::ProfileOptions& opts) { + profiler::ProfilerOptions options; + if (opts.version()) { + options.host_tracer_level = opts.host_tracer_level(); + options.device_tracer_level = opts.device_tracer_level(); + options.enable_python_tracer = opts.python_tracer_level() > 0; + } else { + // use default options value; + } + return options; + } }; } // namespace diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc index 0f9fcf939c3..c1adadeb086 100644 --- a/tensorflow/core/profiler/utils/group_events.cc +++ b/tensorflow/core/profiler/utils/group_events.cc @@ -28,6 +28,22 @@ namespace tensorflow { namespace profiler { namespace { +static const int64 kFunctionalOpEventTypes[] = { + HostEventType::kCallOp, + HostEventType::kParallelForOp, + HostEventType::kForeverOp, + HostEventType::kNumericalGradientOpEvalRight, + HostEventType::kNumericalGradientOpEvalLeft, + HostEventType::kSymbolicGradientOp, + HostEventType::kRemoteCallOp, + HostEventType::kIfOp, + HostEventType::kCaseOp, + HostEventType::kWhileOpEvalCond, + HostEventType::kWhileOpStartBody, + HostEventType::kForOp, + HostEventType::kPartitionedCallOp, +}; + // Creates stat metadata for the stats which may be added by grouping. void CreateStatMetadata(XPlane* plane) { XPlaneBuilder builder(plane); @@ -190,19 +206,24 @@ void EventForest::ConnectInterThread( const std::vector& connect_info_list) { for (const auto& connect_info : connect_info_list) { absl::flat_hash_map, EventNode*> connect_map; - const std::vector& stat_types = connect_info.stat_types; + const std::vector& parent_stat_types = + connect_info.parent_stat_types; + const std::vector* child_stat_types = &connect_info.child_stat_types; + if (child_stat_types->empty()) { + child_stat_types = &parent_stat_types; + } if (auto parent_event_node_list = gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) { for (const auto& parent_event_node : *parent_event_node_list) { std::vector stats; - for (auto stat_type : stat_types) { + for (auto stat_type : parent_stat_types) { const XStat* stat = parent_event_node->GetContextStat(stat_type); if (!stat) break; stats.push_back(stat->value_case() == stat->kInt64Value ? stat->int64_value() : stat->uint64_value()); } - if (stats.size() == stat_types.size()) { + if (stats.size() == parent_stat_types.size()) { connect_map[stats] = parent_event_node.get(); } } @@ -211,14 +232,14 @@ void EventForest::ConnectInterThread( gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) { for (const auto& child_event_node : *child_event_node_list) { std::vector stats; - for (auto stat_type : stat_types) { + for (auto stat_type : *child_stat_types) { const XStat* stat = child_event_node->GetContextStat(stat_type); if (!stat) break; stats.push_back(stat->value_case() == stat->kInt64Value ? stat->int64_value() : stat->uint64_value()); } - if (stats.size() == stat_types.size()) { + if (stats.size() == child_stat_types->size()) { if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) { parent_event_node->AddChild(child_event_node.get()); } @@ -330,24 +351,46 @@ EventForest::EventForest( CreateEventGroup(root_event_types); } +std::vector CreateInterThreadConnectInfoList() { + std::vector connect_info_list = { + {HostEventType::kFunctionRun, + HostEventType::kExecutorStateProcess, + {StatType::kStepId}}, + {HostEventType::kFunctionRun, + HostEventType::kExecutorDoneCallback, + {StatType::kStepId}}, + {HostEventType::kSessionRun, + HostEventType::kExecutorStateProcess, + {StatType::kStepId}}, + {HostEventType::kSessionRun, + HostEventType::kExecutorDoneCallback, + {StatType::kStepId}}, + {HostEventType::kExecutorStateProcess, + HostEventType::kIteratorGetNextOp, + {StatType::kStepId, StatType::kIterNum}}, + {HostEventType::kKernelLaunch, + HostEventType::kKernelExecute, + {StatType::kCorrelationId}}, + {HostEventType::kLocalExecutableExecuteOnLocalDevice, + HostEventType::kLocalExecutableExecute, + {StatType::kRunId}}}; + for (int64 event_type : kFunctionalOpEventTypes) { + connect_info_list.push_back({event_type, + HostEventType::kExecutorStateProcess, + {StatType::kFunctionStepId}, + {StatType::kStepId}}); + connect_info_list.push_back({event_type, + HostEventType::kExecutorDoneCallback, + {StatType::kFunctionStepId}, + {StatType::kStepId}}); + } + return connect_info_list; +} + void GroupTfEvents(XSpace* space, EventGroupNameMap* event_group_name_map) { if (!space) return; - std::vector connect_info_list( - {{HostEventType::kFunctionRun, - HostEventType::kExecutorStateProcess, - {StatType::kStepId}}, - {HostEventType::kSessionRun, - HostEventType::kExecutorStateProcess, - {StatType::kStepId}}, - {HostEventType::kExecutorStateProcess, - HostEventType::kIteratorGetNextOp, - {StatType::kStepId, StatType::kIterNum}}, - {HostEventType::kKernelLaunch, - HostEventType::kKernelExecute, - {StatType::kCorrelationId}}, - {HostEventType::kLocalExecutableExecuteOnLocalDevice, - HostEventType::kLocalExecutableExecute, - {StatType::kRunId}}}); + std::vector connect_info_list = + CreateInterThreadConnectInfoList(); const std::vector root_event_types( {HostEventType::kTraceContext, HostEventType::kFunctionRun, HostEventType::kSessionRun, HostEventType::kHostTrainingLoopIteration}); diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h index a66b5125d47..68daccdfdaf 100644 --- a/tensorflow/core/profiler/utils/group_events.h +++ b/tensorflow/core/profiler/utils/group_events.h @@ -32,7 +32,8 @@ namespace profiler { struct InterThreadConnectInfo { int64 parent_event_type; int64 child_event_type; - std::vector stat_types; + std::vector parent_stat_types; + std::vector child_stat_types; }; // A wrapper for XEvent with parent and children pointers. Through these @@ -136,6 +137,8 @@ class EventForest { EventGroupNameMap event_group_name_map_; }; +std::vector CreateInterThreadConnectInfoList(); + // Calls GroupEvents with connect_info_list and root_event_types specific to // TensorFlow. void GroupTfEvents(XSpace* space, EventGroupNameMap* event_group_name_map); diff --git a/tensorflow/core/profiler/utils/group_events_test.cc b/tensorflow/core/profiler/utils/group_events_test.cc index 94d576527a7..233b7d7b50c 100644 --- a/tensorflow/core/profiler/utils/group_events_test.cc +++ b/tensorflow/core/profiler/utils/group_events_test.cc @@ -97,6 +97,52 @@ TEST(GroupEventsTest, GroupHostTrainingLoopTest) { EXPECT_EQ(event_group_name_map[0], "10"); } +TEST(GroupEventsTest, GroupFunctionalOp) { + XSpace space; + XPlane* host_plane = space.add_planes(); + XPlaneBuilder host_plane_builder(host_plane); + host_plane_builder.SetName(kHostThreads); + host_plane_builder.ReserveLines(2); + + auto main_thread = host_plane_builder.GetOrCreateLine(0); + CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, + 0, 200, {{StatType::kStepNum, 123}}); + CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, + 10, 190, {{StatType::kStepId, 0}}); + + auto tf_executor_thread = host_plane_builder.GetOrCreateLine(0); + CreateXEvent(&host_plane_builder, &tf_executor_thread, + HostEventType::kExecutorStateProcess, 20, 80, + {{StatType::kStepId, 0}}); + CreateXEvent(&host_plane_builder, &tf_executor_thread, + HostEventType::kRemoteCallOp, 30, 70, + {{StatType::kFunctionStepId, 1}}); + CreateXEvent(&host_plane_builder, &tf_executor_thread, + HostEventType::kExecutorStateProcess, 100, 150, + {{StatType::kStepId, 1}}); + + EventGroupNameMap event_group_name_map; + GroupTfEvents(&space, &event_group_name_map); + XPlaneVisitor host_plane_visitor = CreateTfXPlaneVisitor(host_plane); + // Check that RemoteCallOp is grouped correctly so that all events belong + // to the same group. + host_plane_visitor.ForEachLine( + [&](const tensorflow::profiler::XLineVisitor& line) { + line.ForEachEvent( + [&](const tensorflow::profiler::XEventVisitor& event) { + absl::optional group_id; + event.ForEachStat( + [&](const tensorflow::profiler::XStatVisitor& stat) { + if (stat.Type() == StatType::kGroupId) { + group_id = stat.IntValue(); + } + }); + EXPECT_TRUE(group_id.has_value()); + EXPECT_EQ(*group_id, 0); + }); + }); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index dee33f1d1ce..07d1be230f0 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -23,6 +23,9 @@ limitations under the License. namespace tensorflow { namespace profiler { + +const absl::string_view kIdle = "IDLE"; + namespace { class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder { @@ -85,9 +88,9 @@ uint64 IdleTimePs(const OpMetricsDb& metrics_db) { void AddIdleOp(OpMetricsDb* db) { uint64 idle_time_ps = IdleTimePs(*db); OpMetrics* metrics = db->add_metrics_db(); - metrics->set_name("IDLE"); - metrics->set_category("IDLE"); - metrics->set_occurrences(1); + metrics->set_name(string(kIdle)); + metrics->set_category(string(kIdle)); + metrics->set_occurrences(0); metrics->set_time_ps(idle_time_ps); metrics->set_self_time_ps(idle_time_ps); } @@ -102,9 +105,9 @@ OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( builder.UpdateTfOpMetricsWithDeviceOpMetrics(tf_op.name, tf_op.type, device_op_metrics); } else { - DCHECK_EQ(device_op_metrics.name(), "IDLE"); + DCHECK(IsIdleOp(device_op_metrics)); if (with_idle) { - builder.UpdateTfOpMetricsWithDeviceOpMetrics("IDLE", "IDLE", + builder.UpdateTfOpMetricsWithDeviceOpMetrics(kIdle, kIdle, device_op_metrics); } } diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index a1f1a045cdd..7cb776abfe7 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -25,6 +25,10 @@ limitations under the License. namespace tensorflow { namespace profiler { + +// The name of OpMetrics to represent the idle time. +ABSL_CONST_INIT extern const absl::string_view kIdle; + // Helps build an op metrics database (borrowed). // Enables fast lookup of existing ops and prevents the creation of duplicate // ops. It is the user's responsibility to ensure an op metrics database @@ -67,6 +71,11 @@ uint64 IdleTimePs(const OpMetricsDb& metrics_db); // must have been set. void AddIdleOp(OpMetricsDb* db); +// Returns true if the given metrics represents idle time. +inline bool IsIdleOp(const OpMetrics& metrics) { + return metrics.name() == kIdle; +} + // Converts from the device op metrics to Tf-op metrics. OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( const OpMetricsDb& device_op_metrics_db, bool with_idle = true); diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index 9de8028f8eb..0c0924cfb38 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -122,6 +122,8 @@ const StatTypeMap& GetStatTypeMap() { {"fragmentation", kFragmentation}, {"peak_bytes_in_use", kPeakBytesInUse}, {"requested_bytes", kRequestedBytes}, + {"allocation_bytes", kAllocationBytes}, + {"addr", kAddress}, {"shape", kTensorShapes}, // Device trace arguments. {"device_id", kDeviceId}, diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 03e7b8ee720..9e6eaab1036 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -113,6 +113,8 @@ enum StatType { kFragmentation, kPeakBytesInUse, kRequestedBytes, + kAllocationBytes, + kAddress, kTensorShapes, // Device trace arguments. kDeviceId, diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 6f2913eae90..d57ca22b0d2 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -73,7 +73,12 @@ message QueueItem { } message QueueResponse { + // `shape` and `tensor` cannot be set in the same response. + // Shapes of output tensors for creating remote TensorHandles. repeated TensorShapeProto shape = 1; + + // Output tensors of a remote function. Set when Operation.id is invalid. + repeated TensorProto tensor = 2; } message CreateContextRequest { diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto index e139d0b4f18..0158c4be85f 100644 --- a/tensorflow/core/protobuf/struct.proto +++ b/tensorflow/core/protobuf/struct.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; +import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; @@ -60,6 +61,8 @@ message StructuredValue { TensorSpecProto tensor_spec_value = 33; // Represents a value for tf.TypeSpec. TypeSpecProto type_spec_value = 34; + // Represents a value for tf.BoundedTensorSpec. + BoundedTensorSpecProto bounded_tensor_spec_value = 35; // Represents a list of `Value`. ListValue list_value = 51; @@ -103,13 +106,22 @@ message NamedTupleValue { repeated PairValue values = 2; } -// A protobuf to tf.TensorSpec. +// A protobuf to represent tf.TensorSpec. message TensorSpecProto { string name = 1; tensorflow.TensorShapeProto shape = 2; tensorflow.DataType dtype = 3; } +// A protobuf to represent tf.BoundedTensorSpec. +message BoundedTensorSpecProto { + string name = 1; + tensorflow.TensorShapeProto shape = 2; + tensorflow.DataType dtype = 3; + tensorflow.TensorProto minimum = 4; + tensorflow.TensorProto maximum = 5; +} + // Represents a tf.TypeSpec message TypeSpecProto { enum TypeSpecClass { diff --git a/tensorflow/core/public/session_options.h b/tensorflow/core/public/session_options.h index 148e1b6317e..c10cc889ed2 100644 --- a/tensorflow/core/public/session_options.h +++ b/tensorflow/core/public/session_options.h @@ -52,7 +52,7 @@ struct SessionOptions { /// /// If the session disconnects from the remote process during its /// lifetime, session calls may fail immediately. - string target; + std::string target; /// Configuration options. ConfigProto config; diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 1eaa4afaf0e..30cc8331324 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -12022,7 +12022,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12033,7 +12033,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12251,7 +12251,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12262,7 +12262,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -19038,7 +19038,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20109,7 +20109,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21281,7 +21281,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -21989,7 +21989,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22185,7 +22185,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22254,7 +22254,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22428,7 +22428,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22602,7 +22602,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22979,7 +22979,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25322,7 +25322,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25385,7 +25385,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25636,7 +25636,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26120,7 +26120,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -40326,7 +40326,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45852,7 +45852,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46704,7 +46704,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -46775,7 +46775,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java index cf3910b594f..7fd68a0f720 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -17,6 +17,7 @@ package org.tensorflow.op.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.fail; import java.util.List; @@ -142,14 +143,15 @@ public class ZerosTest { } } - @Test(expected = IllegalArgumentException.class) + @Test public void cannotCreateStringZeros() { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros.create(scope, Constant.create(scope, shape), String.class); - } + fail(); + } catch (IllegalArgumentException expected) {} } @Test diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index a2d8b40bbce..fa9e62186fa 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -308,6 +308,18 @@ cc_library( ], ) +cc_library( + name = "tflite_with_xnnpack", + srcs = ["tflite_with_xnnpack.cc"], + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + linkstatic = True, + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ], + alwayslink = 1, +) + cc_test( name = "string_util_test", size = "small", @@ -435,6 +447,32 @@ tf_cc_test( ], ) +# Test model framework with the XNNPACK delegate. +cc_test( + name = "model_xnnpack_test", + size = "small", + srcs = [ + "model_xnnpack_test.cc", + ], + data = [ + "testdata/multi_add.bin", + ], + tags = [ + "no_windows", # No weak symbols with MSVC. + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":framework", + ":tflite_with_xnnpack", + ":util", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test OpResolver. cc_test( name = "mutable_op_resolver_test", diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index 6462427075b..d97eca46929 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -287,14 +287,13 @@ std::vector ArenaPlanner::CreateTensorAllocationVector(int first_node, return this->alloc_node_[idx1] < this->alloc_node_[idx2]; }; - std::set tensors_set; + std::vector tensor_order; for (int i = 0; i < static_cast(graph_info_->num_tensors()); ++i) { if (alloc_node_[i] >= first_node && alloc_node_[i] <= last_node) { - tensors_set.insert(i); + tensor_order.push_back(i); } } // Indices of tensors in order their allocation offsets will be calculated. - std::vector tensor_order(tensors_set.begin(), tensors_set.end()); std::sort(tensor_order.begin(), tensor_order.end(), tensor_compare); return tensor_order; diff --git a/tensorflow/lite/core/macros.h b/tensorflow/lite/core/macros.h index 5ff00e4814a..034ad8daac5 100644 --- a/tensorflow/lite/core/macros.h +++ b/tensorflow/lite/core/macros.h @@ -32,4 +32,23 @@ limitations under the License. #define TFLITE_EXPECT_TRUE(cond) (cond) #endif +// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but +// we avoid the absl dependency for binary size reasons. +#ifdef __has_attribute +#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define TFLITE_HAS_ATTRIBUTE(x) 0 +#endif + +#if (TFLITE_HAS_ATTRIBUTE(weak) || \ + (defined(__GNUC__) && !defined(__clang__))) && \ + !(defined(__llvm__) && defined(_WIN32)) && !defined(__MINGW32__) +#undef TFLITE_ATTRIBUTE_WEAK +#define TFLITE_ATTRIBUTE_WEAK __attribute__((weak)) +#define TFLITE_HAS_ATTRIBUTE_WEAK 1 +#else +#define TFLITE_ATTRIBUTE_WEAK +#define TFLITE_HAS_ATTRIBUTE_WEAK 0 +#endif + #endif // TENSORFLOW_LITE_CORE_MACROS_H_ diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 9a664b28246..04d9eca597e 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -274,9 +274,9 @@ class OpNode { return tensorflow::errors::Internal( "Cannot read from invalid tensor index ", input_index); } - tensorflow::TensorHandle* handle; - TF_RETURN_IF_ERROR(tensorflow::TensorHandle::CreateLocalHandle( - buffer_map->GetTensor(input_index), &handle)); + tensorflow::TensorHandle* handle = + tensorflow::TensorHandle::CreateLocalHandle( + buffer_map->GetTensor(input_index)); op_->MutableInputs()->push_back(handle); } else { // If this is a forwardable tensor, we will remove it from the previous diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index b5fff1d84d5..d6875476dec 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -237,6 +237,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/gl:api2", + "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 2aeff13b3be..2e686810767 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -244,6 +244,7 @@ cc_library( srcs = ["gl_interop.cc"], hdrs = ["gl_interop.h"], deps = [ + ":cl_command_queue", ":cl_context", ":cl_device", ":cl_errors", @@ -252,6 +253,7 @@ cc_library( ":egl_sync", ":environment", ":opencl_wrapper", + "//tensorflow/lite/delegates/gpu:spi", "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/gl:gl_call", diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index a6488c51ce4..09c82307a53 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -86,6 +86,11 @@ class DefaultTensorTie : public TensorTie { const TensorTieDef& def, const TensorObjectConverterBuilder& converter_builder) { auto object_type = def.external_def.object_def.object_type; + if (def.external_def.object_def.user_provided && + GlClBufferCopier::IsSupported(def.external_def.object_def, + def.internal_def.object_def)) { + return true; + } return (object_type == ObjectType::OPENCL_BUFFER || object_type == ObjectType::OPENCL_TEXTURE || object_type == ObjectType::CPU_MEMORY) && @@ -132,10 +137,24 @@ class DefaultTensorTie : public TensorTie { private: absl::Status Init(TensorObjectConverterBuilder* converter_builder, Environment* env) { - RETURN_IF_ERROR(converter_builder->MakeConverter( - def().internal_def, def().external_def, &converter_to_)); - RETURN_IF_ERROR(converter_builder->MakeConverter( - def().external_def, def().internal_def, &converter_from_)); + if (def().external_def.object_def.user_provided && + GlClBufferCopier::IsSupported(def().external_def.object_def, + def().internal_def.object_def)) { + converter_from_ = absl::make_unique( + def().internal_def, def().external_def, env); + } else { + RETURN_IF_ERROR(converter_builder->MakeConverter( + def().external_def, def().internal_def, &converter_from_)); + } + if (def().external_def.object_def.user_provided && + GlClBufferCopier::IsSupported(def().internal_def.object_def, + def().external_def.object_def)) { + converter_to_ = absl::make_unique( + def().internal_def, def().external_def, env); + } else { + RETURN_IF_ERROR(converter_builder->MakeConverter( + def().internal_def, def().external_def, &converter_to_)); + } return MaybeAllocateExternalObject(env); } @@ -356,7 +375,8 @@ class TensorTieFactory { return IsValid(def.external_def.object_def) && (NoopTensorTie::IsSupported(def) || DefaultTensorTie::IsSupported(def, *converter_builder_) || - GlBufferHolder::IsSupported(def, *converter_builder_) || + (gl_interop_fabric_ && + GlBufferHolder::IsSupported(def, *converter_builder_)) || TwoStepTensorTie::IsSupported(def, *converter_builder_)); } @@ -371,12 +391,7 @@ class TensorTieFactory { if (DefaultTensorTie::IsSupported(def, *converter)) { return DefaultTensorTie::New(def, internal_object, converter, &env_, tie); } - if (GlBufferHolder::IsSupported(def, *converter)) { - if (!gl_interop_fabric_) { - return absl::InvalidArgumentError( - "GL object is used but InferenceEnvironmentOptions does not have " - "EGL display and context set."); - } + if (gl_interop_fabric_ && GlBufferHolder::IsSupported(def, *converter)) { return GlBufferHolder::New(def, internal_object, converter, gl_interop_fabric_, &env_, tie); } @@ -526,7 +541,8 @@ class InferenceBuilderImpl : public InferenceBuilder { } RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_)); - if (env_options.IsGlAware()) { + if (env_options.IsGlAware() && + IsGlSharingSupported(environment_->device())) { gl_interop_fabric_ = absl::make_unique( env_options.egl_display, environment_); } @@ -719,9 +735,6 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { IsClEventFromEglSyncSupported(device); properties_.is_cl_to_gl_fast_sync_supported = IsEglSyncFromClEventSupported(); - if (options_.IsGlAware() && !properties_.is_gl_sharing_supported) { - return absl::UnavailableError("GL sharing is not supported"); - } CLContext context; if (options_.context) { @@ -731,7 +744,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { } context = CLContext(options_.context, /* has_ownership = */ false); } else { - if (options_.IsGlAware()) { + if (options_.IsGlAware() && properties_.is_gl_sharing_supported) { RETURN_IF_ERROR(CreateCLGLContext( device, reinterpret_cast(options_.egl_context), diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc index 648b772d827..eaeff2cda07 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc @@ -263,6 +263,46 @@ absl::Status GlInteropFabric::Finish() { return absl::OkStatus(); } +GlClBufferCopier::GlClBufferCopier(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) { + queue_ = environment->queue(); + size_in_bytes_ = + NumElements(input_def) * SizeOf(input_def.object_def.data_type); +} + +absl::Status GlClBufferCopier::Convert(const TensorObject& input_obj, + const TensorObject& output_obj) { + if (absl::get_if(&input_obj)) { + auto ssbo = absl::get_if(&input_obj); + auto cl_mem = absl::get_if(&output_obj); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, ssbo->id)); + void* ptr; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glMapBufferRange, &ptr, + GL_SHADER_STORAGE_BUFFER, 0, + size_in_bytes_, GL_MAP_READ_BIT)); + RETURN_IF_ERROR( + queue_->EnqueueWriteBuffer(cl_mem->memobj, size_in_bytes_, ptr)); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glUnmapBuffer, GL_SHADER_STORAGE_BUFFER)); + } else { + auto cl_mem = absl::get_if(&input_obj); + auto ssbo = absl::get_if(&output_obj); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, ssbo->id)); + void* ptr; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glMapBufferRange, &ptr, + GL_SHADER_STORAGE_BUFFER, 0, + size_in_bytes_, GL_MAP_WRITE_BIT)); + RETURN_IF_ERROR( + queue_->EnqueueReadBuffer(cl_mem->memobj, size_in_bytes_, ptr)); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glUnmapBuffer, GL_SHADER_STORAGE_BUFFER)); + } + return absl::OkStatus(); +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.h b/tensorflow/lite/delegates/gpu/cl/gl_interop.h index 7ebc3e4bf4f..1ca0181e8e5 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.h +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/cl_event.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" +#include "tensorflow/lite/delegates/gpu/spi.h" namespace tflite { namespace gpu { @@ -139,6 +141,29 @@ class GlInteropFabric { AcquiredGlObjects gl_objects_; // transient during Start/Finish calls. }; +// Copies data from(to) GL buffer to(from) CL buffer using CPU. +class GlClBufferCopier : public TensorObjectConverter { + public: + static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { + return input.data_type == output.data_type && + input.data_layout == output.data_layout && + ((input.object_type == ObjectType::OPENGL_SSBO && + output.object_type == ObjectType::OPENCL_BUFFER) || + (input.object_type == ObjectType::OPENCL_BUFFER && + output.object_type == ObjectType::OPENGL_SSBO)); + } + + GlClBufferCopier(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, Environment* environment); + + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override; + + private: + size_t size_in_bytes_; + CLCommandQueue* queue_ = nullptr; +}; + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index ce33c052ef1..cf9b8d2c6eb 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -1,13 +1,13 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - load( "//tensorflow/core/platform:build_config_root.bzl", "tf_gpu_tests_tags", ) +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + cc_library( name = "add", srcs = ["add.cc"], @@ -549,9 +549,9 @@ cc_test( ) cc_library( - name = "depth_wise_conv", - srcs = ["depth_wise_conv.cc"], - hdrs = ["depth_wise_conv.h"], + name = "depthwise_conv", + srcs = ["depthwise_conv.cc"], + hdrs = ["depthwise_conv.h"], deps = [ ":gpu_operation", ":util", @@ -572,9 +572,9 @@ cc_library( ) cc_library( - name = "depth_wise_conv_3d", - srcs = ["depth_wise_conv_3d.cc"], - hdrs = ["depth_wise_conv_3d.h"], + name = "depthwise_conv_3d", + srcs = ["depthwise_conv_3d.cc"], + hdrs = ["depthwise_conv_3d.h"], deps = [ ":gpu_operation", ":util", @@ -595,8 +595,8 @@ cc_library( ) cc_test( - name = "depth_wise_conv_test", - srcs = ["depth_wise_conv_test.cc"], + name = "depthwise_conv_test", + srcs = ["depthwise_conv_test.cc"], linkstatic = True, tags = tf_gpu_tests_tags() + [ "linux", @@ -604,7 +604,7 @@ cc_test( ], deps = [ ":cl_test", - ":depth_wise_conv", + ":depthwise_conv", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", @@ -613,9 +613,9 @@ cc_test( ) cc_library( - name = "depth_wise_conv_3x3", - srcs = ["depth_wise_conv_3x3.cc"], - hdrs = ["depth_wise_conv_3x3.h"], + name = "depthwise_conv_3x3", + srcs = ["depthwise_conv_3x3.cc"], + hdrs = ["depthwise_conv_3x3.h"], deps = [ ":gpu_operation", ":util", @@ -635,8 +635,8 @@ cc_library( ) cc_test( - name = "depth_wise_conv_3x3_test", - srcs = ["depth_wise_conv_3x3_test.cc"], + name = "depthwise_conv_3x3_test", + srcs = ["depthwise_conv_3x3_test.cc"], linkstatic = True, tags = tf_gpu_tests_tags() + [ "linux", @@ -644,7 +644,7 @@ cc_test( ], deps = [ ":cl_test", - ":depth_wise_conv_3x3", + ":depthwise_conv_3x3", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", @@ -1423,8 +1423,8 @@ test_suite( "convolution_transposed_4x4_test", "convolution_transposed_test", "convolution_transposed_thin_test", - "depth_wise_conv_3x3_test", - "depth_wise_conv_test", + "depthwise_conv_3x3_test", + "depthwise_conv_test", "elementwise_test", "fully_connected_test", "lstm_test", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc index 99bec18c7f8..2573f2d7422 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h" #include #include diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h similarity index 96% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h index 67e5f01a256..1c1c55c1989 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_H_ #include @@ -170,4 +170,4 @@ absl::Status CreateDepthWiseConvolution( } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc index 57d30dd2734..5f1d529fba2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h" #include #include diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h similarity index 96% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h index 6c07207bdb7..1d80d5ddca0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3D_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3D_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3D_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3D_H_ #include @@ -167,4 +167,4 @@ absl::Status CreateDepthWiseConvolution3D( } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3D_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3D_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc index 3324adada3b..e4868be7ffc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h" #include #include diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h similarity index 96% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h index 5960a691652..769903adcb2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3X3_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3X3_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3X3_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3X3_H_ #include #include @@ -168,4 +168,4 @@ absl::Status CreateDepthWiseConv3x3( } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTH_WISE_CONV_3X3_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_DEPTHWISE_CONV_3X3_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc similarity index 98% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3_test.cc rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc index eafa94f15d0..6b33cdf90f2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h" #include diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc similarity index 98% rename from tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_test.cc rename to tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc index 71b546bf384..e69b3d99309 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h" #include diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index 6f9b52bd1c9..e9265257c05 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -63,8 +63,8 @@ cc_library( deps = [ "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:precision", - "//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv", - "//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv_3x3", + "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv", + "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv_3x3", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 0098117dea1..72f31154b4b 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -17,8 +17,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 7d1a36d5a61..3d3f685f66a 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -101,10 +101,25 @@ cc_test( srcs = ["model_test.cc"], deps = [ ":model", + ":status", "@com_google_googletest//:gtest_main", ], ) +cc_library( + name = "model_builder_helper", + hdrs = ["model_builder_helper.h"], + deps = [ + ":status", + "//tensorflow/lite:context", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates:utils", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "model_builder", srcs = ["model_builder.cc"], @@ -113,6 +128,7 @@ cc_library( ":custom_parsers", ":data_type", ":model", + ":model_builder_helper", ":operations", ":shape", ":status", @@ -121,7 +137,6 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite:util", "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates:utils", "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h index 1e93a8fe064..fa50dc99d4f 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.h +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h @@ -22,7 +22,17 @@ limitations under the License. namespace tflite { namespace gpu { -enum class GpuType { UNKNOWN, MALI, ADRENO, POWERVR, INTEL, NVIDIA }; +// The VendorID returned by the GPU driver. +enum class GpuType { + UNKNOWN, + APPLE, + MALI, + ADRENO, + POWERVR, + INTEL, + AMD, + NVIDIA, +}; enum class GpuModel { UNKNOWN, // Adreno 6xx series diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index 2e38bcc5f3f..1a68b6975dd 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -181,6 +182,7 @@ class Model : public Graph { return FilterValues([](const ValueDef&) { return true; }); } + // Returns nodes in the execution order. std::vector nodes() const final { return FilterNodes([](const NodeDef&) { return true; }); } @@ -212,7 +214,7 @@ class Model : public Graph { if (id >= nodes_.size()) { return {}; } - return nodes_[id].node.get(); + return nodes_.at(id).node.get(); } Value* GetValue(ValueId id) const final { @@ -222,15 +224,40 @@ class Model : public Graph { return values_[id].value.get(); } + // Append Node to the end of the execution plan. Node* NewNode() final { + const NodeId new_id = nodes_.size(); NodeDef def; - def.node = - absl::make_unique(Node{static_cast(nodes_.size()), {}}); + def.node = absl::make_unique(Node{static_cast(new_id), {}}); Node* node = def.node.get(); - nodes_.push_back(std::move(def)); + nodes_[new_id] = std::move(def); + execution_plan_.push_back(new_id); return node; } + // Insert Node after another in the execution plan. + absl::Status InsertNodeAfter(NodeId id, Node** new_node) { + if (id >= nodes_.size()) { + return absl::OutOfRangeError("NodeId is out of range"); + } + int idx = 0; + while (idx < execution_plan_.size()) { + if (execution_plan_[idx] == id) break; + ++idx; + } + if (idx == execution_plan_.size()) { + return absl::OutOfRangeError("NodeId not in execution plan"); + } + + const NodeId new_id = nodes_.size(); + NodeDef def; + def.node = absl::make_unique(Node{static_cast(new_id), {}}); + *new_node = def.node.get(); + nodes_[new_id] = std::move(def); + execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id); + return absl::OkStatus(); + } + Value* NewValue() final { ValueDef def; def.value = absl::make_unique>( @@ -244,14 +271,14 @@ class Model : public Graph { if (id >= nodes_.size()) { return {}; } - return nodes_[id].inputs; + return nodes_.at(id).inputs; } std::vector*> FindOutputs(NodeId id) const final { if (id >= nodes_.size()) { return {}; } - return nodes_[id].outputs; + return nodes_.at(id).outputs; } Node* FindProducer(ValueId id) const final { @@ -422,6 +449,7 @@ class Model : public Graph { absl::Status MakeExactCopy(Model* model) const { model->nodes_.clear(); + model->execution_plan_.clear(); model->values_.clear(); model->name_ = name_; for (auto& value_def : values_) { @@ -431,10 +459,19 @@ class Model : public Graph { absl::make_unique>(*value_def.value); } } - for (auto& node_def : nodes_) { - model->nodes_.push_back({}); + // Add all nodes first. + for (auto node_id : execution_plan_) { + model->execution_plan_.push_back(node_id); + model->nodes_[node_id] = {}; + auto& node_def = nodes_.at(node_id); + if (node_def.node) { + model->nodes_[node_id].node = absl::make_unique(*node_def.node); + } + } + // Wire up dependencies between nodes. + for (auto node_id : execution_plan_) { + auto& node_def = nodes_.at(node_id); if (node_def.node) { - model->nodes_.back().node = absl::make_unique(*node_def.node); for (auto output : node_def.outputs) { RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id)); } @@ -519,7 +556,8 @@ class Model : public Graph { std::vector FilterNodes(const Pred& predicate) const { std::vector nodes; nodes.reserve(nodes_.size()); - for (auto& n : nodes_) { + for (const auto id : execution_plan_) { + auto& n = nodes_.at(id); if (n.node != nullptr && predicate(n)) { nodes.push_back(n.node.get()); } @@ -533,7 +571,10 @@ class Model : public Graph { // unique_ptr and store it in values_ and nodes_ or store it by value. // We store it by value here to make introspection calls cheaper. std::vector values_; - std::vector nodes_; + + std::map nodes_; + // Node Ids in order of execution. + std::vector execution_plan_; }; // Removes to_remove node that precedes to_keep node only if to_remove has diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 94899efe91e..e2cc431e79b 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -43,12 +42,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" -#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/util.h" @@ -1092,11 +1091,15 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*runtime_inputs=*/1, /*const_inputs=*/0, /*outputs=*/1)); - } else if (IsTwoArgumentOperation()) { - RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, - /*runtime_inputs=*/2, - /*const_inputs=*/0, - /*outputs=*/1)); + // For some elementwise operations (currently only for SUB operation) + // second condition may be false. But it's worth checking the next case + // with const input, which may be supported. + } else if (IsTwoArgumentOperation() && + CheckInputsConstsOutputs(context, tflite_node, + /*runtime_inputs=*/2, + /*const_inputs=*/0, + /*outputs=*/1) + .ok()) { } else if (IsTwoArgumentOperationWithConst()) { RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/1, @@ -1124,11 +1127,13 @@ class ElementwiseOperationParser : public TFLiteOperationParser { /*outputs=*/1)); RETURN_IF_ERROR(reader->AddInput(node, 0)); - } else if (IsTwoArgumentOperation()) { - RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node, - /*runtime_inputs=*/2, - /*const_inputs=*/0, - /*outputs=*/1)); + } else if (IsTwoArgumentOperation() && + reader + ->VerifyInputsConstsOutputs(tflite_node, + /*runtime_inputs=*/2, + /*const_inputs=*/0, + /*outputs=*/1) + .ok()) { if (tflite_node->inputs->size != 2) { return absl::InvalidArgumentError("Applies only two input tensors"); } @@ -2595,6 +2600,37 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { output_value->tensor.shape = output_shape; return absl::OkStatus(); } +}; + +class AlignmentPointsToTransformMatrixOperationParser + : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, + /*outputs=*/1); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + Node* node = graph->NewNode(); + RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points + RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix + + const std::string op_name = "alignment_points_to_transform_matrix"; + node->operation.type = op_name; + BHWC output_shape; + RETURN_IF_ERROR( + ParseCustomAttributes(op_name, tflite_node->custom_initial_data, + tflite_node->custom_initial_data_size, + &(node->operation.attributes), &output_shape)); + + auto output_value = graph->FindOutputs(node->id)[0]; + output_value->tensor.shape = output_shape; + return absl::OkStatus(); + } private: }; @@ -2779,202 +2815,16 @@ std::unique_ptr NewOperationParser( return absl::make_unique(); } + if (custom_name == "AlignmentPointsToTransformMatrix") { + return absl::make_unique< + AlignmentPointsToTransformMatrixOperationParser>(); + } + break; } return absl::make_unique(); } -absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, - TfLiteNode** tflite_node, - TfLiteRegistration** registration) { - if (context->GetNodeAndRegistration(context, node_id, tflite_node, - registration) != kTfLiteOk) { - return absl::InvalidArgumentError(absl::StrCat( - "Couldn't get node and registration info for op: ", node_id)); - } - return absl::OkStatus(); -} - -using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn; - -class GraphWithDequantPartitionHelper - : public tflite::delegates::GraphPartitionHelper { - public: - GraphWithDequantPartitionHelper(TfLiteContext* context, - IsNodeSupportedFn is_node_supported_fn) - : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} - - TfLiteStatus Partition( - std::set* unsupported_nodes_info) override { - const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); - // Clean up those partitions that have a single dequant op. NoteThose - // removed dequant ops have to be reserved in the graph and should not be - // delegated. - RemoveSingleDequantNodePartitions(); - return status; - } - - // Returns a list of node indices of all nodes from the first n largest - // partitions. If there are fewer paritions than n, all nodes will be - // returned. The partition is ranked according to the number of nodes. - std::vector GetNodesOfFirstNLargestPartitions(int n) { - // We first get partitions to reduce the number of nodes to be checked in - // deciding which dequant ops could actually be replaced. And then we - // remap input-tensor to dequant nodes' inputs and remove those - // to-be-reserved dequant nodes. - auto first_nps = GetFirstNLargestPartitions(n); - std::vector ops_to_replace; - for (const auto p : first_nps) { - auto nodes = p->nodes_to_replace; - ops_to_replace.insert(ops_to_replace.end(), nodes->data, - nodes->data + nodes->size); - } - RemapInputTensors(ops_to_replace); - RemoveReservedDequantsFromNodes(&ops_to_replace); - return ops_to_replace; - } - - protected: - bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, - TfLiteRegistration* registration, int node_id, - std::string* unsupported_details) override { - // If we need to handle dequant nodes, we have to remap input tensors of - // this node if some of them come from a dequant node before testing if - // the node is supported. - std::vector orig_inputs; - if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, - &orig_inputs)) { - // We have a dequant op here. Note that we retrun an Ok status because a - // dequant node is first added as supported. Later, this dequant node - // will be removed if it has to be preserved in the graph which happens - // when its immediate downstream nodes cannot be supported. - return true; - } - const auto status = GraphPartitionHelper::IsNodeSupported( - context, node, registration, node_id, unsupported_details); - RestoreToOrigInputTensors(node, orig_inputs); - return status; - } - - private: - // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. - // When it's not a dequant op, remap its inputs to the inputs of the preceding - // dequant if there's a one and returns false. 'orig_inputs' records original - // input tensor ids of this node if any input is remapped. - bool RecordAndRemapInputTensors(int32_t op_code, int node_id, - TfLiteNode* node, - std::vector* orig_inputs) { - orig_inputs->clear(); - // Record the dequant node. - if (op_code == kTfLiteBuiltinDequantize && - context_->tensors[node->inputs->data[0]].type == - TfLiteType::kTfLiteFloat16) { - dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; - return true; - } - // For a dequantize op, there's no need to remap its input tensors. - if (dequant_nodes_.empty()) return false; - RemapInputTensors(node, orig_inputs); - return false; - } - - // Restore inputs of 'node' to 'orig_inputs' only if two sizes match. - void RestoreToOrigInputTensors(TfLiteNode* node, - const std::vector& orig_inputs) { - if (node->inputs->size != orig_inputs.size()) return; - for (int j = 0; j < node->inputs->size; ++j) { - node->inputs->data[j] = orig_inputs[j]; - } - } - - // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of - // them are from dequant ops. - void RemapInputTensors(const std::vector& nodes) const { - for (int node_id : nodes) { - TfLiteNode* node; - TfLiteRegistration* registration; - GetNodeAndRegistration(context_, node_id, &node, ®istration) - .IgnoreError(); - RemapInputTensors(node, nullptr /* orig_inputs*/); - } - } - - void RemoveSingleDequantNodePartitions() { - auto it = partitions_.begin(); - while (it != partitions_.end()) { - auto p = *it; - if (p->nodes_to_replace->size != 1) { - ++it; - continue; - } - int node_id = p->nodes_to_replace->data[0]; - TfLiteNode* node = nullptr; - TfLiteRegistration* registration = nullptr; - GetNodeAndRegistration(context_, node_id, &node, ®istration) - .IgnoreError(); - if (registration->builtin_code != kTfLiteBuiltinDequantize) { - ++it; - continue; - } - // Note such dequant nodes have to be preserved in the graph as dequant - // ops are not actually supported in the GPU delegate. - dequant_nodes_to_save_.insert(node_id); - it = partitions_.erase(it); - } - } - - void RemoveReservedDequantsFromNodes(std::vector* nodes) { - if (dequant_nodes_to_save_.empty()) return; - auto it = nodes->begin(); - while (it != nodes->end()) { - if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { - ++it; - continue; - } - it = nodes->erase(it); - } - } - - // Remap input tensors of a single 'node' if some of come from a dequant op. - // If 'orig_inputs' isn't nullptr, it records original input tensor ids of - // this node if any input is remapped. - void RemapInputTensors(TfLiteNode* node, - std::vector* orig_inputs) const { - TfLiteIntArray* inputs = node->inputs; - auto inputs_view = TfLiteIntArrayView(inputs); - // Prepopulate 'orig_inputs' first and clear it if there's no input from a - // dequant op. - if (orig_inputs) { - orig_inputs->clear(); - orig_inputs->reserve(inputs->size); - for (auto tid : inputs_view) { - orig_inputs->push_back(tid); - } - } - // Fix this node's inputs (i.e. prune out the preceding dequantize node) in - // order to test if it is supported. - bool is_remapped = false; - for (int j = 0; j < inputs->size; ++j) { - const int input_tid = inputs->data[j]; - const auto it = dequant_nodes_.find(input_tid); - if (it != dequant_nodes_.end()) { - inputs->data[j] = it->second; - is_remapped = true; - } - } - if (!is_remapped && orig_inputs) orig_inputs->clear(); - } - - // A map recording dequantize nodes's input/output tensors of this selected - // graph. The key is the output tensor id, and the value is the input tensor - // id. - std::unordered_map dequant_nodes_; - - // A set of dequant nodes as in node indices that have to be preserved in the - // graph. - std::set dequant_nodes_to_save_; -}; - absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, const TfLiteRegistration* registration) { return NewOperationParser(registration) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h new file mode 100644 index 00000000000..cf1367079c7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -0,0 +1,226 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/utils.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace gpu { +inline absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, + TfLiteNode** tflite_node, + TfLiteRegistration** registration) { + if (context->GetNodeAndRegistration(context, node_id, tflite_node, + registration) != kTfLiteOk) { + return absl::InvalidArgumentError(absl::StrCat( + "Couldn't get node and registration info for op: ", node_id)); + } + return absl::OkStatus(); +} + +using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn; + +class GraphWithDequantPartitionHelper + : public tflite::delegates::GraphPartitionHelper { + public: + GraphWithDequantPartitionHelper(TfLiteContext* context, + IsNodeSupportedFn is_node_supported_fn) + : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} + + TfLiteStatus Partition( + std::set* unsupported_nodes_info) override { + const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); + // Clean up those partitions that have a single dequant op. NoteThose + // removed dequant ops have to be reserved in the graph and should not be + // delegated. + RemoveSingleDequantNodePartitions(); + return status; + } + + // Returns a list of node indices of all nodes from the first n largest + // partitions. If there are fewer paritions than n, all nodes will be + // returned. The partition is ranked according to the number of nodes. + std::vector GetNodesOfFirstNLargestPartitions(int n) { + // We first get partitions to reduce the number of nodes to be checked in + // deciding which dequant ops could actually be replaced. And then we + // remap input-tensor to dequant nodes' inputs and remove those + // to-be-reserved dequant nodes. + auto first_nps = GetFirstNLargestPartitions(n); + std::vector ops_to_replace; + for (const auto p : first_nps) { + auto nodes = p->nodes_to_replace; + ops_to_replace.insert(ops_to_replace.end(), nodes->data, + nodes->data + nodes->size); + } + RemapInputTensors(ops_to_replace); + RemoveReservedDequantsFromNodes(&ops_to_replace); + return ops_to_replace; + } + + protected: + bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, int node_id, + std::string* unsupported_details) override { + // If we need to handle dequant nodes, we have to remap input tensors of + // this node if some of them come from a dequant node before testing if + // the node is supported. + std::vector orig_inputs; + if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, + &orig_inputs)) { + // We have a dequant op here. Note that we retrun an Ok status because a + // dequant node is first added as supported. Later, this dequant node + // will be removed if it has to be preserved in the graph which happens + // when its immediate downstream nodes cannot be supported. + return true; + } + const auto status = GraphPartitionHelper::IsNodeSupported( + context, node, registration, node_id, unsupported_details); + RestoreToOrigInputTensors(node, orig_inputs); + return status; + } + + private: + // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. + // When it's not a dequant op, remap its inputs to the inputs of the preceding + // dequant if there's a one and returns false. 'orig_inputs' records original + // input tensor ids of this node if any input is remapped. + bool RecordAndRemapInputTensors(int32_t op_code, int node_id, + TfLiteNode* node, + std::vector* orig_inputs) { + orig_inputs->clear(); + // Record the dequant node. + if (op_code == kTfLiteBuiltinDequantize && + context_->tensors[node->inputs->data[0]].type == + TfLiteType::kTfLiteFloat16) { + dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; + return true; + } + // For a dequantize op, there's no need to remap its input tensors. + if (dequant_nodes_.empty()) return false; + RemapInputTensors(node, orig_inputs); + return false; + } + + // Restore inputs of 'node' to 'orig_inputs' only if two sizes match. + void RestoreToOrigInputTensors(TfLiteNode* node, + const std::vector& orig_inputs) { + if (node->inputs->size != orig_inputs.size()) return; + for (int j = 0; j < node->inputs->size; ++j) { + node->inputs->data[j] = orig_inputs[j]; + } + } + + // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of + // them are from dequant ops. + void RemapInputTensors(const std::vector& nodes) const { + for (int node_id : nodes) { + TfLiteNode* node; + TfLiteRegistration* registration; + GetNodeAndRegistration(context_, node_id, &node, ®istration) + .IgnoreError(); + RemapInputTensors(node, nullptr /* orig_inputs*/); + } + } + + void RemoveSingleDequantNodePartitions() { + auto it = partitions_.begin(); + while (it != partitions_.end()) { + auto p = *it; + if (p->nodes_to_replace->size != 1) { + ++it; + continue; + } + int node_id = p->nodes_to_replace->data[0]; + TfLiteNode* node = nullptr; + TfLiteRegistration* registration = nullptr; + GetNodeAndRegistration(context_, node_id, &node, ®istration) + .IgnoreError(); + if (registration->builtin_code != kTfLiteBuiltinDequantize) { + ++it; + continue; + } + // Note such dequant nodes have to be preserved in the graph as dequant + // ops are not actually supported in the GPU delegate. + dequant_nodes_to_save_.insert(node_id); + it = partitions_.erase(it); + } + } + + void RemoveReservedDequantsFromNodes(std::vector* nodes) { + if (dequant_nodes_to_save_.empty()) return; + auto it = nodes->begin(); + while (it != nodes->end()) { + if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { + ++it; + continue; + } + it = nodes->erase(it); + } + } + + // Remap input tensors of a single 'node' if some of come from a dequant op. + // If 'orig_inputs' isn't nullptr, it records original input tensor ids of + // this node if any input is remapped. + void RemapInputTensors(TfLiteNode* node, + std::vector* orig_inputs) const { + TfLiteIntArray* inputs = node->inputs; + auto inputs_view = TfLiteIntArrayView(inputs); + // Prepopulate 'orig_inputs' first and clear it if there's no input from a + // dequant op. + if (orig_inputs) { + orig_inputs->clear(); + orig_inputs->reserve(inputs->size); + for (auto tid : inputs_view) { + orig_inputs->push_back(tid); + } + } + // Fix this node's inputs (i.e. prune out the preceding dequantize node) in + // order to test if it is supported. + bool is_remapped = false; + for (int j = 0; j < inputs->size; ++j) { + const int input_tid = inputs->data[j]; + const auto it = dequant_nodes_.find(input_tid); + if (it != dequant_nodes_.end()) { + inputs->data[j] = it->second; + is_remapped = true; + } + } + if (!is_remapped && orig_inputs) orig_inputs->clear(); + } + + // A map recording dequantize nodes's input/output tensors of this selected + // graph. The key is the output tensor id, and the value is the input tensor + // id. + std::unordered_map dequant_nodes_; + + // A set of dequant nodes as in node indices that have to be preserved in the + // graph. + std::set dequant_nodes_to_save_; +}; +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index 5cad4d186aa..214d02599d5 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -356,7 +356,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { params->output_tensors->data[0] = 1; // Second partition for DequantNode (t2->t3) - params = interpreter_fp16_add_op->add_delegate_params(); + params = interpreter_fp16_gt_op->add_delegate_params(); params->nodes_to_replace = TfLiteIntArrayCreate(1); params->nodes_to_replace->data[0] = 0; params->input_tensors = TfLiteIntArrayCreate(1); diff --git a/tensorflow/lite/delegates/gpu/common/model_test.cc b/tensorflow/lite/delegates/gpu/common/model_test.cc index 9d3d91b837a..6395bbaa158 100644 --- a/tensorflow/lite/delegates/gpu/common/model_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_test.cc @@ -37,7 +37,7 @@ TEST(Model, SingleNode) { ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node)); + EXPECT_THAT(graph.nodes(), ElementsAre(node)); EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output)); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); @@ -149,12 +149,12 @@ TEST(Model, RemoveSimpleNodeDegenerateCase) { ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node)); + EXPECT_THAT(graph.nodes(), ElementsAre(node)); ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, node).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); EXPECT_THAT(graph.outputs(), UnorderedElementsAre()); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre()); + EXPECT_THAT(graph.nodes(), ElementsAre()); } TEST(Model, RemoveSimpleNodeNoPreviousNode) { @@ -171,12 +171,12 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) { ASSERT_TRUE(graph.SetProducer(consumer_node->id, graph_output->id).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(simple_node, consumer_node)); + EXPECT_THAT(graph.nodes(), ElementsAre(simple_node, consumer_node)); ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(consumer_node)); + EXPECT_THAT(graph.nodes(), ElementsAre(consumer_node)); } TEST(Model, RemoveSimpleNodeNoAfterNodes) { @@ -193,12 +193,12 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) { ASSERT_TRUE(graph.SetProducer(producer_node->id, value->id).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(simple_node, producer_node)); + EXPECT_THAT(graph.nodes(), ElementsAre(simple_node, producer_node)); ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(value)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(producer_node)); + EXPECT_THAT(graph.nodes(), ElementsAre(producer_node)); } TEST(Model, RemoveSimpleNodeGeneralCase) { @@ -220,13 +220,12 @@ TEST(Model, RemoveSimpleNodeGeneralCase) { EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); EXPECT_THAT(graph.nodes(), - UnorderedElementsAre(simple_node, producer_node, consumer_node)); + ElementsAre(simple_node, producer_node, consumer_node)); ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); - EXPECT_THAT(graph.nodes(), - UnorderedElementsAre(producer_node, consumer_node)); + EXPECT_THAT(graph.nodes(), ElementsAre(producer_node, consumer_node)); EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output, value0)); } @@ -275,12 +274,12 @@ TEST(Model, RemoveSimpleNodeComplexCase) { ASSERT_TRUE(graph.SetProducer(n2->id, o2->id).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(n0, n1, n2)); + EXPECT_THAT(graph.nodes(), ElementsAre(n0, n1, n2)); ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, n1).ok()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2)); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(n0, n2)); + EXPECT_THAT(graph.nodes(), ElementsAre(n0, n2)); EXPECT_THAT(graph.values(), UnorderedElementsAre(v0, v1, v3, o1, o2)); EXPECT_THAT(graph.FindInputs(n0->id), ElementsAre(v0, v1)); EXPECT_THAT(graph.FindInputs(n2->id), ElementsAre(v1, v3)); @@ -321,7 +320,7 @@ TEST(Model, ReassignValue) { // \ -> node2 -> graph_output ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2)); + EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2)); EXPECT_THAT(graph.FindInputs(node1->id), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre()); @@ -389,7 +388,7 @@ TEST(Model, DeleteNode) { ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node3->id, graph_output2->id).ok()); - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2, node3)); + EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2, node3)); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2)); @@ -403,7 +402,7 @@ TEST(Model, DeleteNode) { // graph_output2 ASSERT_TRUE(graph.DeleteNode(node3->id).ok()); node3 = nullptr; - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2)); + EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2)); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input, graph_output2)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2)); @@ -414,7 +413,7 @@ TEST(Model, DeleteNode) { // graph_output2 ASSERT_TRUE(graph.DeleteNode(node1->id).ok()); node1 = nullptr; - EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node2)); + EXPECT_THAT(graph.nodes(), ElementsAre(node2)); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(value, graph_output2, graph_input)); EXPECT_THAT(graph.outputs(), @@ -424,7 +423,7 @@ TEST(Model, DeleteNode) { ASSERT_TRUE(graph.DeleteNode(node2->id).ok()); node2 = nullptr; - EXPECT_THAT(graph.nodes(), UnorderedElementsAre()); + EXPECT_THAT(graph.nodes(), ElementsAre()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output, graph_output2, graph_input, value)); EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2, @@ -433,6 +432,40 @@ TEST(Model, DeleteNode) { EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr)); } +TEST(Model, InsertNodeAfter) { + // graph_input -> node1 -> value -> node2 -> graph_output + GraphFloat32 graph; + Node* node1 = graph.NewNode(); + Node* node2 = graph.NewNode(); + Value>* graph_input = graph.NewValue(); + Value>* graph_output = graph.NewValue(); + Value>* value = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok()); + ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); + + EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2)); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2)); + EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1)); + EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value)); + + Node* new_node1; + absl::Status status = graph.InsertNodeAfter(node1->id, &new_node1); + ASSERT_TRUE(status.ok()); + EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2)); + + Node* new_node2; + status = graph.InsertNodeAfter(/*id=*/100, &new_node2); + EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange); + + status = graph.InsertNodeAfter(node2->id, &new_node2); + ASSERT_TRUE(status.ok()); + EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2)); +} + } // namespace } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc index 0011cc24dfa..19153d94f83 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc @@ -48,7 +48,12 @@ class AddQuantAdjustments : public NodeTransformation { } // Add a new QuantizeAndDequantize node. - auto* quant_and_dequant_node = graph->NewNode(); + Node* quant_and_dequant_node; + absl::Status status = + graph->InsertNodeAfter(node->id, &quant_and_dequant_node); + if (!status.ok()) { + return {TransformStatus::INVALID, "Could not insert new node."}; + } quant_and_dequant_node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); QuantizeAndDequantizeAttributes attr; @@ -61,7 +66,7 @@ class AddQuantAdjustments : public NodeTransformation { // The tensor information should rename the same. Value>* adjusted_value = graph->NewValue(); adjusted_value->tensor = output_value->tensor; - absl::Status status = + status = graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id); if (!status.ok()) { return {TransformStatus::INVALID, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc index fc0913d2494..b392ffa87bf 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc @@ -141,13 +141,15 @@ TEST(AddQuantAdjustments, GeneralCase) { EXPECT_EQ(4, graph.nodes().size()); EXPECT_EQ(5, graph.values().size()); EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[0]->operation.type); + // The new node should be inserted at index 1, just after add1. EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), graph.nodes()[1]->operation.type); - EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[2]->operation.type); EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), - graph.nodes()[3]->operation.type); + graph.nodes()[2]->operation.type); + EXPECT_EQ(quant_node->id, graph.nodes()[2]->id); + EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[3]->operation.type); auto new_quant_attr = absl::any_cast( - graph.nodes()[3]->operation.attributes); + graph.nodes()[1]->operation.attributes); EXPECT_EQ(0.0, new_quant_attr.min); EXPECT_EQ(2.0, new_quant_attr.max); const auto& new_quant_consumers = graph.FindConsumers(graph.values()[4]->id); diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 3451119c71d..81974a1db68 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -20,6 +20,7 @@ limitations under the License. #include // NOLINT(build/c++11) #include +#include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/delegates/gpu/api.h" @@ -70,6 +71,28 @@ class Delegate { options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); } + TfLiteDelegate* tflite_delegate() { return &delegate_; } + const TfLiteGpuDelegateOptionsV2& options() const { return options_; } + + private: + TfLiteDelegate delegate_ = { + .data_ = reinterpret_cast(this), + .Prepare = DelegatePrepare, + .CopyFromBufferHandle = nullptr, + .CopyToBufferHandle = nullptr, + .FreeBufferHandle = nullptr, + .flags = kTfLiteDelegateFlagsNone, + }; + + TfLiteGpuDelegateOptionsV2 options_; +}; + +// Represent the execution of a subset of nodes on GPU. +class DelegateKernel { + public: + explicit DelegateKernel(const TfLiteGpuDelegateOptionsV2& options) + : options_(options) {} + absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { thread_id_prepare_ = std::this_thread::get_id(); @@ -133,20 +156,6 @@ class Delegate { return builder->Build(&runner_); } - absl::Status SetInputsAndOutputs(TfLiteContext* context) { - int i = 0; - for (auto index : input_indices_) { - RETURN_IF_ERROR( - runner_->SetInputObject(i++, GetTensorObject(index, context))); - } - i = 0; - for (auto index : output_indices_) { - RETURN_IF_ERROR( - runner_->SetOutputObject(i++, GetTensorObject(index, context))); - } - return absl::OkStatus(); - } - absl::Status Invoke(TfLiteContext* context) { if (thread_id_prepare_ != std::this_thread::get_id()) { TFLITE_LOG(tflite::TFLITE_LOG_WARNING, @@ -162,6 +171,19 @@ class Delegate { return runner_->Run(); } + private: + absl::Status SetInputsAndOutputs(TfLiteContext* context) { + for (int i = 0; i < input_indices_.size(); ++i) { + RETURN_IF_ERROR(runner_->SetInputObject( + i, GetTensorObject(input_indices_[i], context))); + } + for (int i = 0; i < output_indices_.size(); ++i) { + RETURN_IF_ERROR(runner_->SetOutputObject( + i, GetTensorObject(output_indices_[i], context))); + } + return absl::OkStatus(); + } + ObjectDef GetObjectDef(int index) const { ObjectDef default_object_def; default_object_def.data_type = DataType::FLOAT32; @@ -176,9 +198,6 @@ class Delegate { return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes)); } - TfLiteDelegate* tflite_delegate() { return &delegate_; } - - private: absl::Status InitializeOpenClApi(GraphFloat32* graph, std::unique_ptr* builder, bool* graph_is_destroyed) { @@ -230,28 +249,20 @@ class Delegate { return absl::OkStatus(); } - TfLiteDelegate delegate_ = { - reinterpret_cast(this), // .data_ - DelegatePrepare, // .Prepare - nullptr, // .CopyFromBufferHandle - nullptr, // .CopyToBufferHandle - nullptr, // .FreeBufferHandle - kTfLiteDelegateFlagsNone, // .flags - }; - - TfLiteGpuDelegateOptionsV2 options_; + // Shared across all DelegateKernel instances, passed by the Delegate + // instance. + const TfLiteGpuDelegateOptionsV2& options_; std::unique_ptr cl_environment_; std::unique_ptr gl_environment_; std::unique_ptr runner_; std::vector input_indices_; std::vector output_indices_; - std::thread::id thread_id_prepare_; // thread id used for Prapare() bool enforce_same_thread_ = false; // flag to enforce same thread for Invoke }; -inline Delegate* GetDelegate(TfLiteNode* node) { - return reinterpret_cast(node->user_data); +inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) { + return reinterpret_cast(node->user_data); } inline Delegate* GetDelegate(TfLiteDelegate* delegate) { @@ -267,16 +278,20 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { auto* gpu_delegate = GetDelegate(params->delegate); // Everything below should happen in prepare function call, but TFLite // for whatever reason forbids that. - const auto status = gpu_delegate->Prepare(context, params); + auto gpu_delegate_kernel = + absl::make_unique(gpu_delegate->options()); + const auto status = gpu_delegate_kernel->Prepare(context, params); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Init: %s", std::string(status.message()).c_str()); return nullptr; } - return gpu_delegate; + return gpu_delegate_kernel.release(); }, // .free - [](TfLiteContext*, void* buffer) -> void {}, + [](TfLiteContext*, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, // .prepare [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { if (!node->user_data) { @@ -292,7 +307,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { }, // .invoke [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { - const auto status = GetDelegate(node)->Invoke(context); + const auto status = GetDelegateKernel(node)->Invoke(context); if (!status.ok()) { context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", std::string(status.message()).c_str()); diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc index a8246515247..a472d4e5428 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc @@ -92,8 +92,15 @@ class FullyConnectedBuffers : public NodeShader { source += " $output_data_0[0, 0, gid.x] = value_0$;"; std::vector shared_variables = { +#ifdef __APPLE__ + // MoltenVK has problems with shared memory sized using the workgroup + // size. Fortunately with Metal a fixed workgroup size of 32 seems to + // give optimal results. + {"sh_mem", std::vector(32)}, +#else // The actual size of sh_mem depends on the WorkgroupSize {"sh_mem", std::vector(0)}, +#endif }; *generated_code = { diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc index b6c8e144a09..03a414c1547 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc @@ -119,7 +119,7 @@ class Softmax : public NodeShader { if (z < $depth$) { highp vec4 src = $input_data_0[0, 0, z]$; highp vec4 temp = exp(src) * sum; - $output_data_0[0, 0, z]$ = temp; + $output_data_0[0, 0, z] = temp$; offset += 32; } s++; diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index b407083d8d2..192c787b0db 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/metal/kernels", "//tensorflow/lite/delegates/gpu/metal/kernels:custom_registry", ], @@ -172,29 +173,6 @@ objc_library( ], ) -objc_library( - name = "environment_test_lib", - testonly = 1, - srcs = ["environment_test.mm"], - sdk_frameworks = ["XCTest"], - deps = [ - ":environment", - "//tensorflow/lite/delegates/gpu/metal/kernels:test_util", - ], -) - -ios_unit_test( - name = "environment_test", - testonly = 1, - minimum_os_version = "10.0", - runner = tflite_ios_lab_runner("IOS_LATEST"), - tags = tf_gpu_tests_tags() + [ - "notap", - "tflite_not_portable_android", - ], - deps = [":environment_test_lib"], -) - objc_library( name = "inference_context", srcs = ["inference_context.mm"], @@ -272,7 +250,6 @@ objc_library( srcs = [ "//tensorflow/lite/delegates/gpu/metal:common_test.mm", "//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm", - "//tensorflow/lite/delegates/gpu/metal:environment_test.mm", "//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm", ], hdrs = [ diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 744094c8c03..d9c8a369592 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/environment.h" @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -86,12 +88,14 @@ std::vector SelectDepthWiseConv( std::vector SelectConvolutionTransposed( int id, ValueId input_id, ValueId output_id, - const ConvolutionTransposedAttributes& attr, + const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info, const metal::RuntimeOptions& options) { if (CheckConvolutionTransposed4x4Support(attr)) { - return ConvolutionTransposed4x4(id, input_id, output_id, attr, options); + return ConvolutionTransposed4x4(id, input_id, output_id, attr, device_info, + options); } else { - return ConvolutionTransposed(id, input_id, output_id, attr, options); + return ConvolutionTransposed(id, input_id, output_id, attr, device_info, + options); } } @@ -142,10 +146,30 @@ std::vector SelectSpaceToDepth( return SpaceToDepth(id, input_id, output_id, attr); } +bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, + const BHWC& dst_shape) { + const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); + const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4); + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const bool suitable_attributes = + attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && + attr.dilations == HW(1, 1) && attr.strides == HW(1, 1); + + const int min_depth = 16; + const int min_hw = 32; + const bool recommended_channels = + src_depth >= min_depth && dst_depth >= min_depth; + const bool recommended_hw = tiles_x * tiles_y >= min_hw; + return suitable_attributes && recommended_channels && recommended_hw; +} + absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, const std::vector& inputs, const std::vector& outputs, + const DeviceInfo& device_info, const RuntimeOptions& options, + int* last_node_id, int* last_value_id, std::vector* tasks) { if (!IsBatchMatchesForAllValues(graph)) { return absl::InvalidArgumentError( @@ -185,8 +209,36 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape; auto attr = absl::any_cast(node->operation.attributes); - *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape, - attr, options); + if (IsSuitableForWinograd4x4To6x6(attr, dst_shape)) { + int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); + int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4); + + Winograd4x4To36Attributes wino_up_attr; + wino_up_attr.padding = attr.padding; + (*last_node_id) += 1; + int value_id = *last_value_id + 1; + *tasks = + Winograd4x4To36(*last_node_id, inputs[0], value_id, wino_up_attr); + + BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c}; + (*last_node_id) += 1; + auto t1 = + ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1, + conv_shape, attr, device_info, options); + tasks->insert(tasks->end(), t1.begin(), t1.end()); + + Winograd36To4x4Attributes wino_down_attr; + wino_down_attr.output_shape = dst_shape; + wino_down_attr.biases = attr.bias; + (*last_node_id) += 1; + auto t2 = Winograd36To4x4(*last_node_id, value_id + 1, outputs[0], + options, wino_down_attr); + tasks->insert(tasks->end(), t2.begin(), t2.end()); + (*last_value_id) += 2; + } else { + *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape, + attr, device_info, options); + } break; } case OperationType::CONVOLUTION_TRANSPOSED: @@ -194,7 +246,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, node_id, inputs[0], outputs[0], absl::any_cast( node->operation.attributes), - options); + device_info, options); break; case OperationType::DEPTHWISE_CONVOLUTION: *tasks = @@ -207,7 +259,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, *tasks = FullyConnected( node_id, inputs[0], outputs[0], absl::any_cast(node->operation.attributes), - options); + device_info, options); break; case OperationType::MAX_UNPOOLING_2D: *tasks = MaxUnpooling( @@ -340,8 +392,17 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, } // namespace -absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, +absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info, + const RuntimeOptions& options, CompiledModel* compiled_model) { + int last_node_id = 0; + for (const auto& node : graph.nodes()) { + last_node_id = std::max(last_node_id, static_cast(node->id)); + } + int last_value_id = 0; + for (const auto& value : graph.values()) { + last_value_id = std::max(last_value_id, static_cast(value->id)); + } for (const auto& node : graph.nodes()) { std::vector inputs; for (auto& input : graph.FindInputs(node->id)) { @@ -356,7 +417,8 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, RegisterCustomOps(graph, node, inputs, outputs, options, &tasks); if (!custom_status.ok()) { auto primary_status = - RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks); + RegisterPrimaryOps(graph, node, inputs, outputs, device_info, options, + &last_node_id, &last_value_id, &tasks); if (!primary_status.ok()) { return absl::UnimplementedError( absl::Substitute("Unsupported op type: $0; custom registry error: " diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h index c1c7648638c..e4435287518 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.h +++ b/tensorflow/lite/delegates/gpu/metal/api.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -26,7 +27,8 @@ namespace gpu { namespace metal { // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions. -absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, +absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info, + const RuntimeOptions& options, CompiledModel* compiled_model); } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/environment.h b/tensorflow/lite/delegates/gpu/metal/environment.h index f313bacf3ac..732dbe1d18b 100644 --- a/tensorflow/lite/delegates/gpu/metal/environment.h +++ b/tensorflow/lite/delegates/gpu/metal/environment.h @@ -16,21 +16,67 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_ +#include + namespace tflite { namespace gpu { namespace metal { -enum class GpuType { +enum class Vendor { kUnknown, - kA7, // iPhone 5s, iPad Air, iPad Mini 2, iPad Mini 3. - kA8, // A8 iPhone 6, A8X iPad Air 2, iPad Mini 4. - kA9, // A9 iPhone 6s, iPad (2017), A9X iPad Pro (1st generation). - kA10, // iPhone 7, iPad (2018), A10X iPad Pro (2nd generation). - kA11, // iPhone 8/X. - kA12, // iPhone Xs. + kApple, + kIntel, + kAMD, }; -GpuType GetGpuType(); +enum class AppleGPU { + kUnknown, + kA7, + kA8, + kA8X, + kA9, + kA9X, + kA10, + kA10X, + kA11, + kA12, + kA12X, + kA12Z, + kA13, +}; + +struct AppleGPUInfo { + AppleGPUInfo() = default; + explicit AppleGPUInfo(const std::string& device_name); + AppleGPU gpu_type; + + bool IsLocalMemoryPreferredOverGlobal() const; + + bool IsBionic() const; + + // floating point rounding mode + bool IsRoundToNearestSupported() const; + + int GetComputeUnitsCount() const; +}; + +struct DeviceInfo { + DeviceInfo() = default; + explicit DeviceInfo(const std::string& device_name); + + Vendor vendor; + + AppleGPUInfo apple_info; + + bool IsIntelGPU() const; + bool IsAppleGPU() const; + bool IsAMDGPU() const; + + // floating point rounding mode + bool IsRoundToNearestSupported() const; + + int GetComputeUnitsCount() const; +}; } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/environment.mm b/tensorflow/lite/delegates/gpu/metal/environment.mm index 27c51100897..78376b70c8c 100644 --- a/tensorflow/lite/delegates/gpu/metal/environment.mm +++ b/tensorflow/lite/delegates/gpu/metal/environment.mm @@ -15,82 +15,132 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/environment.h" -#import - -#include -#include -#include - -#include "tensorflow/lite/delegates/gpu/metal/common.h" +#include +#include namespace tflite { namespace gpu { namespace metal { - -GpuType GetGpuType() { - int max_feature_set = 0; -#if defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0 - std::vector> features; - if (@available(iOS 8.0, *)) { - features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v1, 7); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v1, 8); - } - if (@available(iOS 9.0, *)) { - features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v2, 7); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v2, 8); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v1, 9); - } - if (@available(iOS 10.0, *)) { - features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v3, 7); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v3, 8); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v2, 9); - } - if (@available(iOS 11.0, *)) { - features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v4, 8); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v3, 9); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v1, 11); - } - if (@available(iOS 12.0, *)) { - features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v5, 7); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v5, 8); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v4, 9); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v2, 11); - features.emplace_back(MTLFeatureSet_iOS_GPUFamily5_v1, 12); - } - id device = GetBestSupportedMetalDevice(); - for (auto &type : features) { - if ([device supportsFeatureSet:type.first]) { - max_feature_set = std::max(max_feature_set, type.second); - } - } -#elif defined(__MAC_10_5) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_5 - std::vector> features; - if (@available(macOS 10.15, *)) { - features.emplace_back(MTLFeatureSet_macOS_GPUFamily2_v1, 12); - } - id device = GetBestSupportedMetalDevice(); - for (auto &type : features) { - if ([device supportsFeatureSet:type.first]) { - max_feature_set = std::max(max_feature_set, type.second); - } - } -#endif - switch (max_feature_set) { - case 7: - return GpuType::kA7; - case 8: - return GpuType::kA8; - case 9: - return GpuType::kA9; - case 10: - return GpuType::kA10; - case 11: - return GpuType::kA11; - case 12: - return GpuType::kA12; - default: - return GpuType::kUnknown; +namespace { +Vendor GetVendorFromString(const std::string& device_name) { + const std::map kMapping = { + {"Apple", Vendor::kApple}, + {"Intel", Vendor::kIntel}, + {"AMD", Vendor::kAMD}, }; + for (auto v : kMapping) { + if (device_name.find(v.first) != std::string::npos) { + return v.second; + } + } + return Vendor::kUnknown; +} +} // namespace + +AppleGPUInfo::AppleGPUInfo(const std::string& device_name) { + const std::map kMapping = { + {"Apple A7 GPU", AppleGPU::kA7}, + {"Apple A8 GPU", AppleGPU::kA8}, + {"Apple A8X GPU", AppleGPU::kA8X}, + {"Apple A9 GPU", AppleGPU::kA9}, + {"Apple A9X GPU", AppleGPU::kA9X}, + {"Apple A10 GPU", AppleGPU::kA10}, + {"Apple A10X GPU", AppleGPU::kA10X}, + {"Apple A11 GPU", AppleGPU::kA11}, + {"Apple A12 GPU", AppleGPU::kA12}, + {"Apple A12X GPU", AppleGPU::kA12X}, + {"Apple A12Z GPU", AppleGPU::kA12Z}, + {"Apple A13 GPU", AppleGPU::kA13}, + }; + auto it = kMapping.find(device_name); + if (it != kMapping.end()) { + gpu_type = it->second; + } else { + gpu_type = AppleGPU::kUnknown; + } +} + +bool AppleGPUInfo::IsLocalMemoryPreferredOverGlobal() const { + return gpu_type == AppleGPU::kA7 || + gpu_type == AppleGPU::kA8 || + gpu_type == AppleGPU::kA8X; +} + +bool AppleGPUInfo::IsBionic() const { + return gpu_type == AppleGPU::kA11 || + gpu_type == AppleGPU::kA12 || + gpu_type == AppleGPU::kA12X || + gpu_type == AppleGPU::kA12Z || + gpu_type == AppleGPU::kA13; +} + +bool AppleGPUInfo::IsRoundToNearestSupported() const { + return IsBionic(); +} + +int AppleGPUInfo::GetComputeUnitsCount() const { + switch (gpu_type) { + case AppleGPU::kA7: + return 4; + case AppleGPU::kA8: + return 4; + case AppleGPU::kA8X: + return 8; + case AppleGPU::kA9: + return 6; + case AppleGPU::kA9X: + return 12; + case AppleGPU::kA10: + return 6; + case AppleGPU::kA10X: + return 12; + case AppleGPU::kA11: + return 3; + case AppleGPU::kA12: + return 4; + case AppleGPU::kA12X: + return 7; + case AppleGPU::kA12Z: + return 8; + case AppleGPU::kA13: + return 4; + case AppleGPU::kUnknown: + return 1; + } +} + +DeviceInfo::DeviceInfo(const std::string& device_name) : vendor(GetVendorFromString(device_name)) { + if (vendor == Vendor::kApple) { + apple_info = AppleGPUInfo(device_name); + } +} + +bool DeviceInfo::IsIntelGPU() const { + return vendor == Vendor::kIntel; +} + +bool DeviceInfo::IsAppleGPU() const { + return vendor == Vendor::kApple; +} + +bool DeviceInfo::IsAMDGPU() const { + return vendor == Vendor::kAMD; +} + +bool DeviceInfo::IsRoundToNearestSupported() const { + if (vendor == Vendor::kApple) { + return apple_info.IsRoundToNearestSupported(); + } else { + return true; + } +} + +int DeviceInfo::GetComputeUnitsCount() const { + if (vendor == Vendor::kApple) { + return apple_info.GetComputeUnitsCount(); + } else { + return 1; + } } } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/environment_test.mm b/tensorflow/lite/delegates/gpu/metal/environment_test.mm deleted file mode 100644 index 98a9ffd53cd..00000000000 --- a/tensorflow/lite/delegates/gpu/metal/environment_test.mm +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/metal/environment.h" - -#import - -#include "tensorflow/lite/delegates/gpu/metal/common.h" - -using ::tflite::gpu::metal::GetGpuType; - -@interface EnvironmentTest : XCTestCase - -@end - -@implementation EnvironmentTest - -- (void)testCompileTimeOSDetection { -#if IOS_VERSION > 0 - XCTAssertTrue(MACOS_VERSION == 0 && TVOS_VERSION == 0, @"IOS_VERSION: %d", int{IOS_VERSION}); -#endif -#if MACOS_VERSION > 0 - XCTAssertTrue(IOS_VERSION == 0 && TVOS_VERSION == 0, @"MACOS_VERSION: %d", int{MACOS_VERSION}); -#endif -#if TVOS_VERSION > 0 - XCTAssertTrue(IOS_VERSION == 0 && MACOS_VERSION == 0, @"TVOS_VERSION: %d", int{TVOS_VERSION}); -#endif -} - -- (void)testGetGpuType { -#if (IOS_VERSION > 0) || (TVOS_VERSION > 0) - auto gpuType = GetGpuType(); - XCTAssertTrue(gpuType != GpuType::kUnknown); -#endif -} - -@end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index 7045614e151..2773f2933cf 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -37,6 +37,7 @@ cc_library( ":softmax", ":space_to_depth", ":transpose_conv", + ":winograd", ], ) @@ -126,6 +127,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common:winograd_util", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", "//tensorflow/lite/delegates/gpu/metal:environment", "//tensorflow/lite/delegates/gpu/metal:runtime_options", @@ -141,6 +143,7 @@ objc_library( deps = [ ":conv", ":test_util", + ":winograd", ], ) @@ -830,6 +833,7 @@ objc_library( "//tensorflow/lite/delegates/gpu/metal:api", "//tensorflow/lite/delegates/gpu/metal:common", "//tensorflow/lite/delegates/gpu/metal:compiled_model", + "//tensorflow/lite/delegates/gpu/metal:environment", "//tensorflow/lite/delegates/gpu/metal:inference_context", "//tensorflow/lite/delegates/gpu/metal:runtime_options", "@FP16", @@ -849,6 +853,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:winograd_util", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc index 73f152412a9..b2a42e83242 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" @@ -53,6 +55,7 @@ struct ConvParams { bool linear_wh; bool linear_whs; WeightsUploadType weights_upload_type; + bool different_weights_for_height = false; bool x_kernel_is_1; bool y_kernel_is_1; }; @@ -264,9 +267,16 @@ kernel void ComputeFunction( if (!params.need_dst_loop) { c += " " + addr_space + " FLT4* tmp = filters;\n"; } else { - c += " " + addr_space + - " FLT4* tmp = filters + Z * 4 * params.src_size.w" + kern_x + - kern_y + ";\n"; + if (params.different_weights_for_height) { + c += " " + addr_space + + " FLT4* tmp = filters + (Z * params.src_size.y + Y * " + + std::to_string(params.block_size.z) + + ") * 4 * params.src_size.w;\n"; + } else { + c += " " + addr_space + + " FLT4* tmp = filters + Z * 4 * params.src_size.w" + kern_x + + kern_y + ";\n"; + } } } if (!params.x_kernel_is_1) { @@ -498,30 +508,29 @@ kernel void ComputeFunction( return c; } -std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, - int z_out) { - const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); - const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - std::vector weights_reordered( - params.weights.shape.w * params.weights.shape.h * - AlignByN(dst_depth, z_out) * 4 * src_depth * 4); +std::vector ReorderWeightsForConv( + const tflite::gpu::Tensor& weights, int z_out) { + const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4); + const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); + std::vector weights_reordered(weights.shape.w * weights.shape.h * + AlignByN(dst_depth, z_out) * 4 * + src_depth * 4); int counter = 0; for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) { - for (int y = 0; y < params.weights.shape.h; ++y) { - for (int x = 0; x < params.weights.shape.w; ++x) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { for (int s = 0; s < src_depth; ++s) { for (int k = 0; k < z_out; ++k) { for (int j = 0; j < 4; ++j) { for (int i = 0; i < 4; ++i) { int src_ch = s * 4 + i; int dst_ch = (d * z_out + k) * 4 + j; - if (src_ch >= params.weights.shape.i || - dst_ch >= params.weights.shape.o) { + if (src_ch >= weights.shape.i || dst_ch >= weights.shape.o) { weights_reordered[counter++] = 0.0f; } else { const size_t f_index = - params.weights.shape.LinearIndex({dst_ch, y, x, src_ch}); - weights_reordered[counter++] = params.weights.data[f_index]; + weights.shape.LinearIndex({dst_ch, y, x, src_ch}); + weights_reordered[counter++] = weights.data[f_index]; } } } @@ -568,6 +577,40 @@ std::vector GetUniformBuffer(const BHWC& src_size, return GetByteBuffer(uniform_params); } +std::vector GetUniformBufferForWinograd(const BHWC& src_size, + const BHWC& dst_size, + const ConvParams& params) { + const int grid_x = IntegralDivideRoundUp(dst_size.w, params.block_size.x); + const int grid_y = IntegralDivideRoundUp(dst_size.h, params.block_size.y); + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + 1, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + grid_x, + grid_x * grid_y, + 0, // dummy, for alignment + 0, // dummy, for alignment + params.work_group_size.x, + params.work_group_size.y, + params.work_group_size.z, + 0, // dummy, for alignment + }; + return GetByteBuffer(uniform_params); +} + int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size, const int3& block_size) { const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); @@ -616,37 +659,33 @@ bool IsKernelYIs1(const Convolution2DAttributes& attr) { attr.padding.appended.h == 0; } -int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) { - if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { +int GetMaximumPossibleWavesCount(const AppleGPUInfo& apple_info, + const BHWC& dst_shape) { + if (apple_info.IsLocalMemoryPreferredOverGlobal()) { return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1}); } else { return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1}); } } -int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) { - const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu); - int base_threshold; - if (gpu == GpuType::kA7 || gpu == GpuType::kA8) { - base_threshold = 32; - } else if (gpu == GpuType::kA11) { - base_threshold = 48; - } else { - base_threshold = 64; - } - if (max_waves >= base_threshold * 4) { +int GetRecommendedBlockSize(const AppleGPUInfo& apple_info, + const BHWC& dst_shape) { + const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape); + const int cu_count = apple_info.GetComputeUnitsCount(); + if (max_waves >= cu_count * 64) { return 8; - } else if (max_waves >= base_threshold * 2) { + } else if (max_waves >= cu_count * 32) { return 4; - } else if (max_waves >= base_threshold) { + } else if (max_waves >= cu_count * 16) { return 2; } else { return 1; } } -ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr, - const BHWC& dst_shape, GpuType gpu) { +ConvParams GetConvParamsForA7A8(const AppleGPUInfo& apple_info, + const Convolution2DAttributes& attr, + const BHWC& dst_shape) { const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); @@ -660,7 +699,7 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr, params.linear_whs = false; params.work_group_launch_order = int3(0, 1, 2); - int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); + int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape); if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) { params.block_size.z = 4; @@ -720,14 +759,14 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr, return params; } -ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr, - const BHWC& dst_shape, GpuType gpu) { +ConvParams GetConvParamsForA9AndHigher(const AppleGPUInfo& apple_info, + const Convolution2DAttributes& attr, + const BHWC& dst_shape) { const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); - int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu); - bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12; + int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape); int3 block_size = int3(1, 1, 1); - if (blk_total_size >= 2 && apple_gpu) { + if (blk_total_size >= 2 && apple_info.IsBionic()) { if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) { block_size.x = 2; } else { @@ -765,7 +804,7 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr, params.work_group_size = int3(32, 1, 1); params.work_group_launch_order = int3(0, 1, 2); } - float precise_threshold = gpu == GpuType::kA12 ? 1.0f : 1.04f; + float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f; float precise_ratio = static_cast(g2) / static_cast(g3); if (precise_ratio > precise_threshold) { params.linear_wh = false; @@ -801,22 +840,130 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr, return params; } -ConvParams GetConvParams(const Convolution2DAttributes& attr, - const BHWC& dst_shape) { - auto gpu_type = GetGpuType(); - if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) { - return GetConvParamsForA7A8(attr, dst_shape, gpu_type); - } else { - return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type); +ConvParams GetConvParamsForIntel(const Convolution2DAttributes& attr, + const BHWC& dst_shape) { + ConvParams params; + params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + params.src_depth_loop_size = 1; + params.linear_wh = false; + params.linear_whs = false; + params.work_group_launch_order = int3(2, 0, 1); + params.block_size = int3(1, 2, 4); + params.work_group_size = int3(8, 2, 1); + + int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size); + int g2 = GetGroupsCountForLinearWH(dst_shape, {16, 1, 1}, params.block_size); + int g3 = GetGroupsCountForLinearWHS(dst_shape, {16, 1, 1}, params.block_size); + + if (g2 < g1) { + params.linear_wh = true; + params.work_group_size = int3(16, 1, 1); + params.work_group_launch_order = int3(1, 0, 2); } + + float precise_threshold = 2.0f; + float precise_ratio = static_cast(g2) / static_cast(g3); + if (precise_ratio > precise_threshold) { + params.linear_wh = false; + params.linear_whs = true; + params.work_group_size = int3(16, 1, 1); + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + } + + return params; +} + +ConvParams GetConvParamsForAMD(const Convolution2DAttributes& attr, + const BHWC& dst_shape) { + ConvParams params; + params.block_size = int3(1, 1, 4); + params.work_group_size = int3(8, 4, 1); + params.work_group_launch_order = int3(2, 0, 1); + params.src_depth_loop_size = 1; + params.need_src_loop = true; + params.need_dst_loop = true; + params.linear_wh = false; + params.linear_whs = false; + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + params.different_weights_for_height = false; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + return params; +} + +ConvParams GetConvParams(const DeviceInfo& device_info, + const Convolution2DAttributes& attr, + const BHWC& dst_shape) { + if (device_info.IsAppleGPU()) { + if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) { + return GetConvParamsForA7A8(device_info.apple_info, attr, dst_shape); + } else { + return GetConvParamsForA9AndHigher(device_info.apple_info, attr, + dst_shape); + } + } else if (device_info.IsIntelGPU()) { + return GetConvParamsForIntel(attr, dst_shape); + } else if (device_info.IsAMDGPU()) { + return GetConvParamsForAMD(attr, dst_shape); + } else { + ConvParams params; + params.block_size = int3(1, 1, 4); + params.work_group_size = int3(8, 4, 1); + params.work_group_launch_order = int3(2, 0, 1); + params.src_depth_loop_size = 1; + params.need_src_loop = true; + params.need_dst_loop = true; + params.linear_wh = false; + params.linear_whs = false; + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + params.different_weights_for_height = false; + params.x_kernel_is_1 = IsKernelXIs1(attr); + params.y_kernel_is_1 = IsKernelYIs1(attr); + return params; + } +} + +std::pair GetDispatchSizes(const ConvParams& params, + const BHWC& shape) { + const int dst_slices = IntegralDivideRoundUp(shape.c, 4); + + int grid_x = IntegralDivideRoundUp(shape.w, params.block_size.x); + int grid_y = IntegralDivideRoundUp(shape.h, params.block_size.y); + int grid_z = IntegralDivideRoundUp(dst_slices, params.block_size.z); + + const uint3 group_size(params.work_group_size.x, params.work_group_size.y, + params.work_group_size.z); + int3 wg; + uint3 groups_count; + if (params.linear_whs) { + wg.x = IntegralDivideRoundUp(grid_x * grid_y * grid_z, + params.work_group_size.x); + groups_count = uint3(wg.x, 1, 1); + } else if (params.linear_wh) { + wg.x = IntegralDivideRoundUp(grid_x * grid_y, params.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_z, params.work_group_size.y); + groups_count = uint3(wg[params.work_group_launch_order.x], + wg[params.work_group_launch_order.y], 1); + } else { + wg.x = IntegralDivideRoundUp(grid_x, params.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_y, params.work_group_size.y); + wg.z = IntegralDivideRoundUp(grid_z, params.work_group_size.z); + groups_count = uint3(wg[params.work_group_launch_order.x], + wg[params.work_group_launch_order.y], + wg[params.work_group_launch_order.z]); + } + return std::make_pair(group_size, groups_count); } } // namespace std::vector ConvolutionGeneric( int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, - const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) { - ConvParams params = GetConvParams(attr, dst_shape); + const Convolution2DAttributes& attr, const DeviceInfo& device_info, + const metal::RuntimeOptions& options) { + ConvParams params = GetConvParams(device_info, attr, dst_shape); auto desc = std::make_shared(); desc->id = id; @@ -835,7 +982,8 @@ std::vector ConvolutionGeneric( return out_shape; }}; - auto weights_reordered = ReorderWeightsForConv(attr, params.block_size.z); + auto weights_reordered = + ReorderWeightsForConv(attr.weights, params.block_size.z); std::string addr_space = params.weights_upload_type == WeightsUploadType::CONSTANT_MEM ? "constant" : "device"; @@ -861,35 +1009,79 @@ std::vector ConvolutionGeneric( desc->resize_function = [output_id, params](const std::map& buffers) { - const auto& output_dims = buffers.find(output_id)->second; - const int dst_slices = IntegralDivideRoundUp(output_dims.c, 4); + return GetDispatchSizes(params, buffers.find(output_id)->second); + }; - int grid_x = IntegralDivideRoundUp(output_dims.w, params.block_size.x); - int grid_y = IntegralDivideRoundUp(output_dims.h, params.block_size.y); - int grid_z = IntegralDivideRoundUp(dst_slices, params.block_size.z); + return {desc}; +} - const uint3 group_size(params.work_group_size.x, params.work_group_size.y, - params.work_group_size.z); - int3 wg; - uint3 groups_count; - if (params.linear_whs) { - wg.x = IntegralDivideRoundUp(grid_x * grid_y * grid_z, - params.work_group_size.x); - groups_count = uint3(wg.x, 1, 1); - } else if (params.linear_wh) { - wg.x = IntegralDivideRoundUp(grid_x * grid_y, params.work_group_size.x); - wg.y = IntegralDivideRoundUp(grid_z, params.work_group_size.y); - groups_count = uint3(wg[params.work_group_launch_order.x], - wg[params.work_group_launch_order.y], 1); - } else { - wg.x = IntegralDivideRoundUp(grid_x, params.work_group_size.x); - wg.y = IntegralDivideRoundUp(grid_y, params.work_group_size.y); - wg.z = IntegralDivideRoundUp(grid_z, params.work_group_size.z); - groups_count = uint3(wg[params.work_group_launch_order.x], - wg[params.work_group_launch_order.y], - wg[params.work_group_launch_order.z]); - } - return std::make_pair(group_size, groups_count); +std::vector ConvolutionWino4x4To6x6( + int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, + const Convolution2DAttributes& attr, const DeviceInfo& device_info, + const RuntimeOptions& options) { + const int dst_slices = IntegralDivideRoundUp(attr.weights.shape.o, 4); + ConvParams params; + params.work_group_launch_order = int3(2, 0, 1); + params.src_depth_loop_size = 1; + params.need_src_loop = true; + params.need_dst_loop = true; + params.linear_wh = false; + params.linear_whs = false; + params.different_weights_for_height = true; + params.x_kernel_is_1 = true; + params.y_kernel_is_1 = true; + if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) { + params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; + params.work_group_size = int3(32, 1, 1); + params.block_size = int3(4, 1, 4); + } else { + params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; + params.work_group_size = int3(8, 4, 1); + params.block_size = int3(4, 1, 4); + } + + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GenerateConvolution(params); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + const auto src_shape = buffers.find(input_id)->second; + return BHWC(src_shape.b, src_shape.h, src_shape.w, + attr.weights.shape.o); + }}; + + ::tflite::gpu::Tensor wino_weights; + RearrangeWeightsToWinograd4x4To6x6Weights(attr.weights, &wino_weights); + auto weights_reordered = + ReorderWeightsForConv(wino_weights, params.block_size.z); + std::vector dummy_biases(AlignByN(dst_slices, params.block_size.z) * 4, + 0.0f); + desc->immutable_buffers = { + {"device FLT4* const filters", + GetByteBufferConverted(weights_reordered, options.storage_precision)}, + {"device FLT4* const biases", + GetByteBufferConverted(dummy_biases, options.storage_precision)}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const auto& dst_shape = buffers.find(output_id)->second; + return GetUniformBufferForWinograd(src_shape, dst_shape, params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + return GetDispatchSizes(params, buffers.find(output_id)->second); }; return {desc}; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h index 2853631abe8..def4ba5e08a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -29,7 +30,13 @@ namespace metal { std::vector ConvolutionGeneric( int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, - const Convolution2DAttributes& params, const RuntimeOptions& options); + const Convolution2DAttributes& attr, const DeviceInfo& device_info, + const RuntimeOptions& options); + +std::vector ConvolutionWino4x4To6x6( + int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape, + const Convolution2DAttributes& attr, const DeviceInfo& device_info, + const RuntimeOptions& options); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm index a74b22cf13e..36f5938c10f 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h" #import @@ -38,7 +39,10 @@ using ::tflite::gpu::Linear; using ::tflite::gpu::OHWI; using ::tflite::gpu::OperationType; using ::tflite::gpu::Tensor; +using ::tflite::gpu::TensorFloat32; using ::tflite::gpu::TensorRef; +using ::tflite::gpu::ValueId; +using ::tflite::gpu::IntegralDivideRoundUp; using ::tflite::gpu::metal::CompareVectors; using ::tflite::gpu::metal::SingleOpModel; @@ -241,4 +245,89 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } +- (void)testWinograd4x4To6x6 { + const int src_channels = 7; + const int dst_channels = 13; + Convolution2DAttributes attr; + attr.padding.prepended = HW(0, 0); + attr.padding.appended = HW(10, 10); + attr.strides = HW(1, 1); + attr.dilations = HW(1, 1); + attr.weights.shape = OHWI(dst_channels, 3, 3, src_channels); + attr.weights.data.resize(attr.weights.shape.DimensionsProduct()); + for (int i = 0; i < attr.weights.data.size(); ++i) { + attr.weights.data[i] = sin(i); + } + attr.bias.shape = Linear(dst_channels); + attr.bias.data.resize(attr.bias.shape.DimensionsProduct()); + for (int i = 0; i < attr.bias.data.size(); ++i) { + attr.bias.data[i] = sin(i); + } + + auto src_shape = BHWC(1, 17, 13, src_channels); + auto dst_shape = CalculateOutputShape(src_shape, attr); + int new_width = src_shape.w + attr.padding.prepended.w + + attr.padding.appended.w - 2; + int new_height = src_shape.h + attr.padding.prepended.h + + attr.padding.appended.h - 2; + std::cout << dst_shape.w << " vs " << new_width << std::endl; + std::cout << dst_shape.h << " vs " << new_height << std::endl; + BHWC conv_shape; + conv_shape.b = dst_shape.b; + conv_shape.h = 36; + conv_shape.w = IntegralDivideRoundUp(new_width, 4) * IntegralDivideRoundUp(new_height, 4); + conv_shape.c = dst_shape.c; + + TensorFloat32 src_tensor; + src_tensor.shape = src_shape; + src_tensor.data.resize(src_tensor.shape.DimensionsProduct()); + for (int i = 0; i < src_tensor.data.size(); ++i) { + src_tensor.data[i] = sin(i); + } + + id device = MTLCreateSystemDefaultDevice(); + tflite::gpu::metal::RuntimeOptions options; + options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + + std::map inputs_v0; + inputs_v0[0] = src_tensor; + std::map outputs_v0; + outputs_v0[1].shape = dst_shape; + outputs_v0[1].data.resize(dst_shape.DimensionsProduct()); + + auto tasks_v0 = tflite::gpu::metal::ConvolutionGeneric(0, 0, 1, dst_shape, attr, options); + + auto status = RunGraph(tasks_v0, device, inputs_v0, &outputs_v0); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + + std::map inputs_v1; + inputs_v1[0] = src_tensor; + std::map outputs_v1; + outputs_v1[1].shape = dst_shape; + outputs_v1[1].data.resize(outputs_v1[1].shape.DimensionsProduct()); + + tflite::gpu::metal::Winograd4x4To36Attributes wino_up_attr; + wino_up_attr.padding = attr.padding; + auto tasks_v1 = tflite::gpu::metal::Winograd4x4To36(0, 0, 2, wino_up_attr); + + auto tasks_v2 = tflite::gpu::metal::ConvolutionWino4x4To6x6(1, 2, 3, conv_shape, attr, options); + + tflite::gpu::metal::Winograd36To4x4Attributes wino_down_attr; + wino_down_attr.output_shape = dst_shape; + wino_down_attr.biases = attr.bias; + auto tasks_v3 = tflite::gpu::metal::Winograd36To4x4(2, 3, 1, options, wino_down_attr); + + std::vector tasks; + tasks.insert(tasks.end(), tasks_v1.begin(), tasks_v1.end()); + tasks.insert(tasks.end(), tasks_v2.begin(), tasks_v2.end()); + tasks.insert(tasks.end(), tasks_v3.begin(), tasks_v3.end()); + + status = RunGraph(tasks, device, inputs_v1, &outputs_v1); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + + status = CompareVectors(outputs_v0[1].data, outputs_v1[1].data, 1e-4f); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc index 9fa627bcac2..6c26a87c267 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc @@ -475,91 +475,93 @@ std::vector DepthWiseConvolution( std::string shader_source = R"( #include using namespace metal; - constant int kernel_x = $0; - constant int kernel_y = $1; struct uniforms { - int4 stride; - int4 padding; - int4 dilation; - int4 size; + int4 src_size; + int4 dst_size; + int2 stride; + int2 padding; + int2 dilation; + int2 kernel_size; int4 channel_multiplier; }; - $$0 + $0 kernel void ComputeFunction( - $$1 + $1 uint tid[[thread_index_in_threadgroup]], uint3 gid[[thread_position_in_grid]]) { - const bool outside = static_cast(gid.x) >= params.size.z || - static_cast(gid.y) >= params.size.w; - if (outside) { - return; - } - device FLT4* temp = filters + gid.z * kernel_y * kernel_x; - float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f); + int dst_x = static_cast(gid.x); + int dst_y = static_cast(gid.y); + int dst_z = static_cast(gid.z); - for(int ky = 0; ky < kernel_y; ++ky) { - for(int kx = 0; kx < kernel_x; ++kx) { - int2 coords = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dilation.xy - - params.padding.xy; - const bool outside = coords.x < 0 || coords.y < 0 || - coords.x >= params.size.x || coords.y >= params.size.y; - if (outside) continue; + if (dst_x >= U.dst_size.x || dst_y >= U.dst_size.y) return; + + device FLT4* temp = filters + dst_z * U.kernel_size.x * U.kernel_size.y; + ACCUM_FLT4 sum0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + + int src_x = dst_x * U.stride.x + U.padding.x; + int src_y = dst_y * U.stride.y + U.padding.y; + + for(int ky = 0; ky < U.kernel_size.y; ++ky) { + int yc = ky * U.dilation.y + src_y; + if (yc < 0 || yc >= U.src_size.y) continue; + for(int kx = 0; kx < U.kernel_size.x; ++kx) { + int xc = kx * U.dilation.x + src_x; + if (xc < 0 || xc >= U.src_size.x) continue; )"; if (channels_multiplier == 1) { shader_source += R"( - const int src_layer = gid.z; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src_modified = src_buffer[src_index]; + int src_layer = dst_z; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src_modified = src_buffer[src_index]; )"; } else if (channels_multiplier == 2) { shader_source += R"( - const int src_layer = gid.z / 2; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; - const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw; - const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y); + int src_layer = dst_z / 2; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; + FLT2 t0 = dst_z % 2 == 0 ? src.xy : src.zw; + FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y); )"; } else if (channels_multiplier == 4) { shader_source += R"( - const int src_layer = gid.z / 4; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; - const FLT t0 = src[gid.z % 4]; - const FLT4 src_modified = FLT4(t0, t0, t0, t0); + int src_layer = dst_z / 4; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; + FLT t0 = src[dst_z % 4]; + FLT4 src_modified = FLT4(t0, t0, t0, t0); )"; } else { shader_source += R"( - const int src_layer = gid.z / params.channel_multiplier.x; - const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; - const FLT4 src = src_buffer[src_index]; + int src_layer = dst_z / U.channel_multiplier.x; + int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc; + FLT4 src = src_buffer[src_index]; FLT4 src_modified; - const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4; - src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x]; - src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x]; - src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x]; - src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x]; + const int src_layer_offset = (dst_z % U.channel_multiplier.x) * 4; + src_modified.x = src[(src_layer_offset + 0) / U.channel_multiplier.x]; + src_modified.y = src[(src_layer_offset + 1) / U.channel_multiplier.x]; + src_modified.z = src[(src_layer_offset + 2) / U.channel_multiplier.x]; + src_modified.w = src[(src_layer_offset + 3) / U.channel_multiplier.x]; )"; } shader_source += R"( - sum0 += float4(src_modified * temp[ky * kernel_x + kx]); + sum0 += TO_ACCUM4_TYPE(src_modified * temp[ky * U.kernel_size.x + kx]); } } - FLT4 res = FLT4(sum0 + float4(biases[gid.z])); - const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x); + FLT4 res = FLT4(sum0) + biases[dst_z]; + const int linear_index = (dst_z * U.dst_size.y + dst_y) * U.dst_size.x + dst_x; FLT4 value = res; - $$2 - output_buffer[linear_index] = value; + $2 + dst_buffer[linear_index] = value; } )"; - desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w, - attr.weights.shape.h); + desc->shader_source = shader_source; desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, }; desc->output_buffer = { - output_id, "device FLT4* output_buffer", + output_id, "device FLT4* dst_buffer", [input_id, attr](const std::map& buffers) { auto out_shape = CalculateOutputShape(buffers.find(input_id)->second, attr); @@ -577,27 +579,27 @@ std::vector DepthWiseConvolution( }; desc->uniform_buffers = { - {"constant uniforms& params", + {"constant uniforms& U", [input_id, output_id, attr](const std::map& buffers) { const auto& dimension = buffers.find(input_id)->second; const auto& output_dimension = buffers.find(output_id)->second; std::vector uniform_params{ - attr.strides.w, - attr.strides.h, - 1, - 1, - attr.padding.prepended.w, - attr.padding.prepended.h, - 1, - 1, - attr.dilations.w, - attr.dilations.h, - 1, - 1, dimension.w, dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + 0, output_dimension.w, output_dimension.h, + IntegralDivideRoundUp(output_dimension.c, 4), + 0, + attr.strides.w, + attr.strides.h, + -attr.padding.prepended.w, + -attr.padding.prepended.h, + attr.dilations.w, + attr.dilations.h, + attr.weights.shape.w, + attr.weights.shape.h, attr.weights.shape.o, 0, 0, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm index d76507253a9..dcf550f7868 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm @@ -167,4 +167,43 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } +- (void)testShape2x2Kernel2x2 { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 1); + + DepthwiseConvolution2DAttributes attr; + Tensor bias; + bias.shape.v = 1; + bias.id = 1; + bias.data = {0}; + attr.bias = std::move(bias); + + Tensor weights; + weights.shape = OHWI(1, 2, 2, 1); + weights.id = 1; + weights.data = {1, 2, 3, 4}; + + attr.weights = std::move(weights); + + attr.dilations = HW(1, 1); + attr.padding.prepended = HW(0, 0); + attr.padding.appended = HW(1, 1); + attr.strides = HW(1, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 3; + output.shape = BHWC(1, 2, 2, 1); + + SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input}, + {output}); + XCTAssertTrue(model.PopulateTensor(0, {1, 4, 9, 16})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({100, 52, 41, 16}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc index ed24ce25d29..283b03ce707 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc @@ -119,12 +119,12 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels, std::vector FullyConnected( int id, ValueId input_id, ValueId output_id, - const FullyConnectedAttributes& attr, const RuntimeOptions& options) { + const FullyConnectedAttributes& attr, const DeviceInfo& device_info, + const RuntimeOptions& options) { auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - auto gpu_type = GetGpuType(); - bool shared = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8; + bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal(); desc->shader_source = GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h index 00d73fdf944..3e1f26fc7a8 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -33,7 +34,8 @@ namespace metal { // will be inefficient std::vector FullyConnected( int id, ValueId input_id, ValueId output_id, - const FullyConnectedAttributes& attr, const RuntimeOptions& options); + const FullyConnectedAttributes& attr, const DeviceInfo& device_info, + const RuntimeOptions& options); } // namespace metal } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm index 80c0e2457af..a1b414f0060 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/inference_context.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" namespace tflite { namespace gpu { @@ -77,15 +78,17 @@ absl::Status SingleOpModel::Invoke() { output_ids.push_back(output.id); } + id device = MTLCreateSystemDefaultDevice(); + std::string device_name = std::string([[device name] UTF8String]); + DeviceInfo device_info(device_name); RuntimeOptions options; options.storage_precision = RuntimeOptions::Precision::FP32; options.accumulator_precision = RuntimeOptions::Precision::FP32; CompiledModel compiled_model; - RETURN_IF_ERROR(Compile(graph_, options, &compiled_model)); + RETURN_IF_ERROR(Compile(graph_, device_info, options, &compiled_model)); CompiledModel optimized_model; RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model)); - id device = MTLCreateSystemDefaultDevice(); TFLInferenceContext* graph = [[TFLInferenceContext alloc] init]; RETURN_IF_ERROR([graph compileModelWithDevice:device taskDescriptors:optimized_model diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc index 9c3f91df7f2..1b6e6963fb5 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc @@ -275,618 +275,6 @@ std::string GetDeconvolutionShared(const ConvolutionTransposedAttributes& attr, src_local_size_x, src_local_size_y, workgroup_x, workgroup_y); } -struct GridParams { - uint rect_offsets[4]; - uint widths[4]; - short2 origins[4]; - uint elements_count; -}; - -struct Params3x3 { - short2 inner_size; - short2 src_offset; - short2 dst_offset; -}; - -void Init3x3(const ConvolutionTransposedAttributes& attr, const int2& src_size, - const int2& dst_size, GridParams* grid_params, - Params3x3* params3x3) { - short2 src_size_scaled; - src_size_scaled.x = (src_size.x - 1) * 2; - src_size_scaled.y = (src_size.y - 1) * 2; - short2 top_left_src, bottom_right_src; - top_left_src.x = 1 - attr.padding.prepended.w; - top_left_src.y = 1 - attr.padding.prepended.h; - bottom_right_src.x = top_left_src.x + src_size_scaled.x; - bottom_right_src.y = top_left_src.y + src_size_scaled.y; - short2 top_left_inner, bottom_right_inner; - if (top_left_src.x >= 0) { - top_left_inner.x = top_left_src.x; - } else { - top_left_inner.x = std::abs(top_left_src.x % 2); - } - if (top_left_src.y >= 0) { - top_left_inner.y = top_left_src.y; - } else { - top_left_inner.y = std::abs(top_left_src.y % 2); - } - - if (bottom_right_src.x <= dst_size.x) { - bottom_right_inner.x = bottom_right_src.x; - } else { - bottom_right_inner.x = dst_size.x; - } - if (top_left_src.x % 2 == 0) { - bottom_right_inner.x -= bottom_right_inner.x % 2; - } else { - if (bottom_right_inner.x % 2 == 0) { - bottom_right_inner.x -= 1; - } - } - bottom_right_inner.x -= 1; - - if (bottom_right_src.y <= dst_size.y) { - bottom_right_inner.y = bottom_right_src.y; - } else { - bottom_right_inner.y = dst_size.y; - } - if (top_left_src.y % 2 == 0) { - bottom_right_inner.y -= bottom_right_inner.y % 2; - } else { - if (bottom_right_inner.y % 2 == 0) { - bottom_right_inner.y -= 1; - } - } - bottom_right_inner.y -= 1; - - params3x3->dst_offset = top_left_inner; - params3x3->src_offset.x = (top_left_inner.x - top_left_src.x) / 2; - params3x3->src_offset.y = (top_left_inner.y - top_left_src.y) / 2; - params3x3->inner_size.x = - std::max(0, bottom_right_inner.x - top_left_inner.x + 1) / 2; - params3x3->inner_size.y = - std::max(0, bottom_right_inner.y - top_left_inner.y + 1) / 2; - - short2 top_rect, bottom_rect, left_rect, right_rect; - - top_rect.x = dst_size.x; - top_rect.y = top_left_inner.y; - - bottom_rect.x = dst_size.x; - bottom_rect.y = dst_size.y - bottom_right_inner.y - 1; - - left_rect.x = top_left_inner.x; - left_rect.y = dst_size.y - top_rect.y - bottom_rect.y; - - right_rect.x = dst_size.x - bottom_right_inner.x - 1; - right_rect.y = left_rect.y; - - grid_params->widths[0] = top_rect.x; - grid_params->widths[1] = left_rect.x; - grid_params->widths[2] = right_rect.x; - grid_params->widths[3] = bottom_rect.x; - - grid_params->rect_offsets[0] = 0; - grid_params->rect_offsets[1] = - grid_params->rect_offsets[0] + top_rect.x * top_rect.y; - grid_params->rect_offsets[2] = - grid_params->rect_offsets[1] + left_rect.x * left_rect.y; - grid_params->rect_offsets[3] = - grid_params->rect_offsets[2] + right_rect.x * right_rect.y; - grid_params->elements_count = - grid_params->rect_offsets[3] + bottom_rect.x * bottom_rect.y; - - grid_params->origins[0] = short2(0, 0); - grid_params->origins[1] = short2(int16_t(0), int16_t(top_rect.y)); - grid_params->origins[2] = - short2(int16_t(dst_size.x - right_rect.x), int16_t(top_rect.y)); - grid_params->origins[3] = short2(0, dst_size.y - bottom_rect.y); -} - -std::string GetDeconvolutionBorder( - const ConvolutionTransposedAttributes& attr) { - std::string constant_args = R"( - constant short2 padding = {$0, $1}; - constant short2 stride = {$2, $3}; - constant short2 kernel_size = {$4, $5}; - constant short2 inner_size = {$6, $7}; - constant short2 kernel_offset = {$8, $9}; - )"; - std::string shader_source = R"( - #include - using namespace metal; - - struct FilterStripe { - FLT4 vals[$0]; - }; - - constant int src_depth = $1; - constant int dst_depth = $2; - constant int dst_channels = $3; - constant int dst_channels_aligned = $4; - - $5 - - struct uniforms { - int2 src_size; - int2 dst_size; - uint rect_offsets[4]; - uint widths[4]; - short2 origins[4]; - uint elements_count; - }; - - short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params); - - short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params) { - int index = 0; - index = linear_id >= params.rect_offsets[0] ? 0 : index; - index = linear_id >= params.rect_offsets[1] ? 1 : index; - index = linear_id >= params.rect_offsets[2] ? 2 : index; - index = linear_id >= params.rect_offsets[3] ? 3 : index; - - const uint rect_index = linear_id - params.rect_offsets[index]; - - const uint rect_width = params.widths[index]; - const short2 offset = short2(rect_index % rect_width, rect_index / rect_width); - return params.origins[index] + offset; - } - - $$0 - kernel void ComputeFunction( - $$1 - uint linear_id[[thread_position_in_grid]]) { - if (linear_id >= params.elements_count) { - return; - } - short2 gid_sh = GetGridIdByLinearId(linear_id, params); - - float out[$4]; - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - short2 offset = gid_sh + padding - kernel_offset; - offset.x = offset.x % stride.x; - offset.y = offset.y % stride.y; - offset += stride; - offset.x = offset.x % stride.x; - offset.y = offset.y % stride.y; - short2 f_offset; - f_offset.x = offset.x == 0 ? 0 : stride.x - offset.x; - f_offset.y = offset.y == 0 ? 0 : stride.y - offset.y; - for (int ky = 0; ky < inner_size.y; ++ky) { - for (int kx = 0; kx < inner_size.x; ++kx) { - short2 index = short2(kx, ky) * stride + f_offset; - bool inside_kernel = index.x < kernel_size.x && index.y < kernel_size.y; - const short2 src_coord = (gid_sh + index + padding - kernel_offset) / stride; - index = kernel_size - short2(1, 1) - index; - bool outside = src_coord.x < 0 || src_coord.y < 0 || - src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y; - const int kernel_index = index.y * kernel_size.x + index.x; - bool belong = inside_kernel && !outside; - if (belong) { - for (int l = 0; l < src_depth; ++l) { - const int src_index = (l * params.src_size.y + src_coord.y) * - params.src_size.x + src_coord.x; - FLT4 srcColor = src_buffer[src_index]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor, filters[kernel_index].vals[l * dst_channels_aligned + k]); - } - } - } - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(gid_sh.y)) * - params.dst_size.x + int(gid_sh.x); - uint3 gid = uint3(uint(gid_sh.x), uint(gid_sh.y), uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - )"; - const int kernel_x = attr.weights.shape.w; - const int kernel_y = attr.weights.shape.h; - const int inner_size_x = (kernel_x - 1) / attr.stride.w + 1; - const int inner_size_y = (kernel_y - 1) / attr.stride.h + 1; - std::string constant_args_inplaced = absl::Substitute( - constant_args, attr.padding.prepended.w, attr.padding.prepended.h, - attr.stride.w, attr.stride.h, kernel_x, kernel_y, inner_size_x, - inner_size_y, kernel_x - 1, kernel_y - 1); - const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); - const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); - return absl::Substitute(shader_source, src_depth * dst_channels_aligned, - src_depth, dst_depth, attr.weights.shape.o, - dst_channels_aligned, constant_args_inplaced); -} - -std::string GetDeconvolution3x3(const ConvolutionTransposedAttributes& attr) { - std::string shader_source = R"( - #include - using namespace metal; - - struct FilterStripe { - FLT4 vals[$0]; - }; - - constant int src_depth = $1; - constant int dst_depth = $2; - constant int dst_channels = $3; - constant int dst_channels_aligned = $4; - - struct uniforms { - int2 src_size; - int2 dst_size; - short2 inner_size; - short2 src_offset; - short2 dst_offset; - }; - - $$0 - kernel void ComputeFunction( - $$1 - uint tid[[thread_index_in_threadgroup]], - uint2 ugid[[thread_position_in_grid]]) { - if (static_cast(ugid.x) >= params.inner_size.x || - static_cast(ugid.y) >= params.inner_size.y) { - return; - } - - float out[$4]; - short2 src_coord_0 = short2(ugid) + params.src_offset; - short2 dst_coord = short2(ugid) * 2 + params.dst_offset; - - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, filters[4].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - - short2 src_coord_1 = src_coord_0 + short2(1, 0); - dst_coord += short2(1, 0); - - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * - params.src_size.x + src_coord_1.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_1 = src_buffer[src_index_1]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, filters[5].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_1, filters[3].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - - short2 src_coord_2 = src_coord_0 + short2(0, 1); - dst_coord += short2(-1, 1); - - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * - params.src_size.x + src_coord_2.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_2 = src_buffer[src_index_2]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, filters[7].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_2, filters[1].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - - short2 src_coord_3 = src_coord_0 + short2(1, 1); - dst_coord += short2(1, 0); - - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * - params.src_size.x + src_coord_1.x; - const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * - params.src_size.x + src_coord_2.x; - const int src_index_3 = (l * params.src_size.y + src_coord_3.y) * - params.src_size.x + src_coord_3.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_1 = src_buffer[src_index_1]; - FLT4 srcColor_2 = src_buffer[src_index_2]; - FLT4 srcColor_3 = src_buffer[src_index_3]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, filters[8].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_1, filters[6].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_2, filters[2].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_3, filters[0].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - )"; - - const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); - const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); - return absl::Substitute(shader_source, src_depth * dst_channels_aligned, - src_depth, dst_depth, attr.weights.shape.o, - dst_channels_aligned); -} - -std::string GetDeconvolutionShared3x3( - const ConvolutionTransposedAttributes& attr) { - std::string shader_source = R"( - #include - using namespace metal; - - struct FilterStripe { - FLT4 vals[$0]; - }; - - constant int src_depth = $1; - constant int dst_depth = $2; - constant int dst_channels = $3; - constant int dst_channels_aligned = $4; - - struct uniforms { - int2 src_size; - int2 dst_size; - short2 inner_size; - short2 src_offset; - short2 dst_offset; - }; - - $$0 - kernel void ComputeFunction( - $$1 - uint tid[[thread_index_in_threadgroup]], - uint2 ugid[[thread_position_in_grid]]) { - - float out[$4]; - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - threadgroup FilterStripe stripes[4]; - threadgroup_barrier(mem_flags::mem_none); - if (tid < dst_channels) { - for (int l = 0; l < src_depth; ++l) { - stripes[0].vals[l * dst_channels_aligned + tid] - = filters[4].vals[l * dst_channels_aligned + tid]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - bool inside_grid = (static_cast(ugid.x) < params.inner_size.x) - && (static_cast(ugid.y) < params.inner_size.y); - - short2 src_coord_0 = short2(ugid) + params.src_offset; - short2 dst_coord = short2(ugid) * 2 + params.dst_offset; - - if (inside_grid) { - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(ugid.x, ugid.y, uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - - short2 src_coord_1 = src_coord_0 + short2(1, 0); - dst_coord += short2(1, 0); - - threadgroup_barrier(mem_flags::mem_none); - if (tid < dst_channels) { - for (int l = 0; l < src_depth; ++l) { - stripes[0].vals[l * dst_channels_aligned + tid] - = filters[5].vals[l * dst_channels_aligned + tid]; - stripes[1].vals[l * dst_channels_aligned + tid] - = filters[3].vals[l * dst_channels_aligned + tid]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (inside_grid) { - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * - params.src_size.x + src_coord_1.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_1 = src_buffer[src_index_1]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(ugid.x, ugid.y, uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - - short2 src_coord_2 = src_coord_0 + short2(0, 1); - dst_coord += short2(-1, 1); - - threadgroup_barrier(mem_flags::mem_none); - if (tid < dst_channels) { - for (int l = 0; l < src_depth; ++l) { - stripes[0].vals[l * dst_channels_aligned + tid] - = filters[7].vals[l * dst_channels_aligned + tid]; - stripes[1].vals[l * dst_channels_aligned + tid] - = filters[1].vals[l * dst_channels_aligned + tid]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (inside_grid) { - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * - params.src_size.x + src_coord_2.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_2 = src_buffer[src_index_2]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_2, stripes[1].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(ugid.x, ugid.y, uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - - short2 src_coord_3 = src_coord_0 + short2(1, 1); - dst_coord += short2(1, 0); - - threadgroup_barrier(mem_flags::mem_none); - if (tid < dst_channels) { - for (int l = 0; l < src_depth; ++l) { - stripes[0].vals[l * dst_channels_aligned + tid] - = filters[8].vals[l * dst_channels_aligned + tid]; - stripes[1].vals[l * dst_channels_aligned + tid] - = filters[6].vals[l * dst_channels_aligned + tid]; - stripes[2].vals[l * dst_channels_aligned + tid] - = filters[2].vals[l * dst_channels_aligned + tid]; - stripes[3].vals[l * dst_channels_aligned + tid] - = filters[0].vals[l * dst_channels_aligned + tid]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (inside_grid) { - for (short l = 0; l < dst_depth * 4; ++l) { - out[l] = float(0.0f); - } - - for (int l = 0; l < src_depth; ++l) { - const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * - params.src_size.x + src_coord_0.x; - const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * - params.src_size.x + src_coord_1.x; - const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * - params.src_size.x + src_coord_2.x; - const int src_index_3 = (l * params.src_size.y + src_coord_3.y) * - params.src_size.x + src_coord_3.x; - FLT4 srcColor_0 = src_buffer[src_index_0]; - FLT4 srcColor_1 = src_buffer[src_index_1]; - FLT4 srcColor_2 = src_buffer[src_index_2]; - FLT4 srcColor_3 = src_buffer[src_index_3]; - for (int k = 0; k < dst_channels; ++k) { - out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_2, stripes[2].vals[l * dst_channels_aligned + k]); - out[k] += dot(srcColor_3, stripes[3].vals[l * dst_channels_aligned + k]); - } - } - - for (short l = 0; l < dst_depth; ++l) { - FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; - const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * - params.dst_size.x + int(dst_coord.x); - uint3 gid = uint3(ugid.x, ugid.y, uint(l)); - $$2 - dst_buffer[linear_index] = value; - } - } - } - )"; - const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); - const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); - return absl::Substitute(shader_source, src_depth * dst_channels_aligned, - src_depth, dst_depth, attr.weights.shape.o, - dst_channels_aligned); -} - std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) { std::string c = R"( #include @@ -1054,7 +442,7 @@ std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) { std::vector ConvolutionTransposed( int id, ValueId input_id, ValueId output_id, const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options) { + const DeviceInfo& device_info, const RuntimeOptions& options) { auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; @@ -1066,9 +454,8 @@ std::vector ConvolutionTransposed( const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); const int shared_size = sizeof(float) * 4 * src_depth * src_local_size_x * src_local_size_y; - auto gpu_type = GetGpuType(); if (shared_size < 1000 * 16 && - (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8)) { + device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) { desc->shader_source = GetDeconvolutionShared(params, kThreadGroupWidth, kThreadGroupHeight); } else { @@ -1152,204 +539,10 @@ std::vector ConvolutionTransposed( return {desc}; } -std::vector ConvolutionTransposed3x3( - int id, ValueId input_id, ValueId output_id, - const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options) { - const int kThreadGroupWidth = 16; - const int kThreadGroupHeight = 4; - - auto border_desc = std::make_shared(); - border_desc->id = id; - border_desc->is_linkable = false; - - border_desc->shader_source = GetDeconvolutionBorder(params); - - border_desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - border_desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - const auto& src_shape = buffers.find(input_id)->second; - BHWC dst_shape = CalculateOutputShape(src_shape, params); - return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c}; - }}; - - const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); - const int src_ch_aligned = AlignByN(params.weights.shape.i, 4); - const int dst_ch_aligned = AlignByN(params.weights.shape.o, 4); - const int kernel_x = params.weights.shape.w; - const int kernel_y = params.weights.shape.h; - const int filters_aligned_size = - src_ch_aligned * dst_ch_aligned * kernel_x * kernel_y; - std::vector filters_reordered(filters_aligned_size); - - int counter = 0; - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - for (int ch = 0; ch < src_depth; ++ch) { - for (int f = 0; f < dst_ch_aligned; ++f) { - for (int i = 0; i < 4; ++i) { - if (ch * 4 + i >= params.weights.shape.i || - f >= params.weights.shape.o) { - filters_reordered[counter++] = 0.0f; - } else { - const int f_index = - params.weights.shape.LinearIndex({f, y, x, ch * 4 + i}); - filters_reordered[counter++] = params.weights.data[f_index]; - } - } - } - } - } - } - - auto filters = - GetByteBufferConverted(filters_reordered, options.storage_precision); - auto biases = GetByteBufferConvertedResized( - params.bias.data, options.storage_precision, params.weights.shape.o); - border_desc->immutable_buffers = { - {"device FilterStripe* const filters", filters}, - {"constant FLT4* const biases", biases}, - }; - - border_desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& src_dim = buffers.find(input_id)->second; - const auto& dst_dim = buffers.find(output_id)->second; - GridParams grid_params; - Params3x3 params3x3; - Init3x3(params, int2(src_dim.w, src_dim.h), int2(dst_dim.w, dst_dim.h), - &grid_params, ¶ms3x3); - int* ptr = reinterpret_cast(&grid_params); - std::vector uniform_params{ - src_dim.w, - src_dim.h, - dst_dim.w, - dst_dim.h, - /*uint GridParams.rect_offsets[4]*/ - ptr[0], - ptr[1], - ptr[2], - ptr[3], - /*uint GridParams.widths[4]*/ - ptr[4], - ptr[5], - ptr[6], - ptr[7], - /*short2 GridParams.origins[4]*/ - ptr[8], - ptr[9], - ptr[10], - ptr[11], - /*uint GridParams.elements_count*/ - ptr[12], - }; - return GetByteBuffer(uniform_params); - }}, - }; - - border_desc->resize_function = - [input_id, params](const std::map& buffers) { - const uint3 groups_size{kThreadGroupWidth * kThreadGroupHeight, 1, 1}; - const auto& src_shape = buffers.find(input_id)->second; - BHWC dst_shape = CalculateOutputShape(src_shape, params); - GridParams grid_params; - Params3x3 params3x3; - Init3x3(params, int2(src_shape.w, src_shape.h), - int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); - if (grid_params.elements_count == 0) { - return std::make_pair(groups_size, uint3{0, 0, 0}); - } - int groups_x = - IntegralDivideRoundUp(grid_params.elements_count, groups_size.x); - int groups_y = 1; - int groups_z = 1; - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; - - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - - const int shared_size = sizeof(float) * 4 * src_depth * dst_ch_aligned * 4; - auto gpu_type = GetGpuType(); - if (shared_size < (1024 * 16 - 32) && - (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) && - dst_ch_aligned <= kThreadGroupWidth * kThreadGroupHeight) { - desc->shader_source = GetDeconvolutionShared3x3(params); - } else { - desc->shader_source = GetDeconvolution3x3(params); - } - - desc->input_buffers = { - {input_id, "device FLT4* const src_buffer"}, - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id, params](const std::map& buffers) { - const auto& src_shape = buffers.find(input_id)->second; - BHWC dst_shape = CalculateOutputShape(src_shape, params); - return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c}; - }}; - - desc->immutable_buffers = { - {"device FilterStripe* const filters", filters}, - {"constant FLT4* const biases", biases}, - }; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id, output_id, params](const std::map& buffers) { - const auto& src_shape = buffers.find(input_id)->second; - const auto& dst_shape = buffers.find(output_id)->second; - GridParams grid_params; - Params3x3 params3x3; - Init3x3(params, int2(src_shape.w, src_shape.h), - int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); - int* ptr = reinterpret_cast(¶ms3x3); - std::vector uniform_params{ - src_shape.w, - src_shape.h, - dst_shape.w, - dst_shape.h, - /*short2 Params3x3.inner_size*/ ptr[0], - /*short2 Params3x3.src_offset*/ ptr[1], - /*short2 Params3x3.dst_offset*/ ptr[2], - }; - return GetByteBuffer(uniform_params); - }}, - }; - - desc->resize_function = [input_id, - params](const std::map& buffers) { - const uint3 groups_size{kThreadGroupWidth, kThreadGroupHeight, 1}; - const auto& src_shape = buffers.find(input_id)->second; - BHWC dst_shape = CalculateOutputShape(src_shape, params); - GridParams grid_params; - Params3x3 params3x3; - Init3x3(params, int2(src_shape.w, src_shape.h), - int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); - if (params3x3.inner_size.x * params3x3.inner_size.y == 0) { - return std::make_pair(groups_size, uint3{0, 0, 0}); - } - int groups_x = IntegralDivideRoundUp(params3x3.inner_size.x, groups_size.x); - int groups_y = IntegralDivideRoundUp(params3x3.inner_size.y, groups_size.y); - int groups_z = 1; - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; - - return {border_desc, desc}; -} - std::vector ConvolutionTransposed4x4( int id, ValueId input_id, ValueId output_id, const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options) { + const DeviceInfo& device_info, const RuntimeOptions& options) { const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); const int kernel_x = 4; @@ -1402,12 +595,10 @@ std::vector ConvolutionTransposed4x4( desc->id = id; desc->is_linkable = false; - const auto gpu_type = GetGpuType(); - const bool powervr = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8 || - gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10; const bool recommended_2x = - !powervr && options.storage_precision == RuntimeOptions::Precision::FP16; - const bool use_local_mem = powervr; + device_info.apple_info.IsBionic() && + options.storage_precision == RuntimeOptions::Precision::FP16; + const bool use_local_mem = !device_info.apple_info.IsBionic(); const int2 block_size(recommended_2x ? 2 : 1, 1); desc->shader_source = GetDeconvolution4x4(block_size, use_local_mem); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h index cffab3cf90e..54dd2f93dcc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -30,17 +31,12 @@ namespace metal { std::vector ConvolutionTransposed( int id, ValueId input_id, ValueId output_id, const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options); - -std::vector ConvolutionTransposed3x3( - int id, ValueId input_id, ValueId output_id, - const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options); + const DeviceInfo& device_info, const RuntimeOptions& options); std::vector ConvolutionTransposed4x4( int id, ValueId input_id, ValueId output_id, const ConvolutionTransposedAttributes& params, - const RuntimeOptions& options); + const DeviceInfo& device_info, const RuntimeOptions& options); bool CheckConvolutionTransposed4x4Support( const ConvolutionTransposedAttributes& attr); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc index 62dff5aa487..f1c9d75e62a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { @@ -43,7 +44,7 @@ struct uniforms { int4 src_size; int4 dst_size; int2 padding; - int2 dummy0; + int2 tiles_count; }; )"; auto bt_mat = BtMatrixForWinograd4x4To6x6(); @@ -65,7 +66,7 @@ kernel void ComputeFunction($1 { int3 gid = int3(ugid.x * 4, ugid.y * 4, ugid.z); - if (gid.x >= U.src_size.x || gid.y >= U.src_size.y) return; + if (ugid.x >= U.tiles_count.x || ugid.y >= U.tiles_count.y) return; FLT4 I[6][6]; for (int y = 0; y < 6; ++y) { @@ -107,7 +108,7 @@ kernel void ComputeFunction($1 } c += R"( - int dst_x = ugid.y * (U.src_size.x + 3) / 4 + ugid.x; + int dst_x = ugid.y * U.tiles_count.x + ugid.x; int dst_adress = gid.z * U.dst_size.y * U.dst_size.x + dst_x; for (int y = 0; y < 6; ++y) { dst_buffer[dst_adress] = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4]; @@ -156,15 +157,14 @@ $0 kernel void ComputeFunction($1 uint3 global_ids[[thread_position_in_grid]]) { - int3 gid = int3(global_ids.x, global_ids.y, global_ids.z); - int tile_id = global_ids.x; + int Z = static_cast(global_ids.z); int tiles_count_x = (U.dst_size.x + 3) / 4; int tile_x = (tile_id % tiles_count_x) * 4; int tile_y = (tile_id / tiles_count_x) * 4; if (tile_x >= U.dst_size.x || tile_y >= U.dst_size.y) return; - int src_adress = gid.z * U.src_size.y * U.src_size.x + gid.x; + int src_adress = Z * U.src_size.y * U.src_size.x + tile_id; FLT4 I[4][6]; for (int y = 0; y < 4; ++y) { for (int x = 0; x < 6; ++x) { @@ -181,15 +181,15 @@ kernel void ComputeFunction($1 } } - FLT4 bias_val = biases[gid.z]; - int dst_adress = (gid.z * U.dst_size.y + tile_y) * U.dst_size.x + tile_x; + FLT4 bias_val = biases[Z]; + int dst_adress = (Z * U.dst_size.y + tile_y) * U.dst_size.x + tile_x; for (int y = 0; y < 4 && tile_y + y < U.dst_size.y; ++y) { FLT4 t0 = I[y][1] + I[y][2]; FLT4 t1 = I[y][3] + I[y][4]; if (tile_x < U.dst_size.x) { FLT4 value = I[y][0] + t0 + t1 + bias_val; int linear_index = dst_adress; - uint3 ugid = uint3(tile_x, tile_y + y, global_ids.z); + uint3 gid = uint3(tile_x, tile_y + y, global_ids.z); $2 dst_buffer[linear_index] = value; } @@ -198,20 +198,20 @@ kernel void ComputeFunction($1 if (tile_x + 1 < U.dst_size.x) { FLT4 value = t2 * At[7] + t3 * At[9] + bias_val; int linear_index = dst_adress + 1; - uint3 ugid = uint3(tile_x + 1, tile_y + y, global_ids.z); + uint3 gid = uint3(tile_x + 1, tile_y + y, global_ids.z); $2 dst_buffer[linear_index] = value; } if (tile_x + 2 < U.dst_size.x) { FLT4 value = t0 * At[13] + t1 * At[15] + bias_val; int linear_index = dst_adress + 2; - uint3 ugid = uint3(tile_x + 2, tile_y + y, global_ids.z); + uint3 gid = uint3(tile_x + 2, tile_y + y, global_ids.z); $2 dst_buffer[linear_index] = value; } if (tile_x + 3 < U.dst_size.x) { FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val; - uint3 ugid = uint3(tile_x + 3, tile_y + y, global_ids.z); + uint3 gid = uint3(tile_x + 3, tile_y + y, global_ids.z); int linear_index = dst_adress + 3; $2 dst_buffer[linear_index] = value; @@ -236,24 +236,34 @@ std::vector Winograd4x4To36( {input_id, "device FLT4* const src_buffer"}, }; - desc->output_buffer = {output_id, "device FLT4* dst_buffer", - [input_id](const std::map& buffers) { - const auto src_shape = - buffers.find(input_id)->second; - BHWC dst_shape; - dst_shape.b = src_shape.b; - dst_shape.h = 36; - dst_shape.w = IntegralDivideRoundUp(src_shape.w, 4) * - IntegralDivideRoundUp(src_shape.h, 4); - dst_shape.c = src_shape.c; - return dst_shape; - }}; + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + const auto src_shape = buffers.find(input_id)->second; + int new_width = src_shape.w + attr.padding.prepended.w + + attr.padding.appended.w - 2; + int new_height = src_shape.h + attr.padding.prepended.h + + attr.padding.appended.h - 2; + BHWC dst_shape; + dst_shape.b = src_shape.b; + dst_shape.h = 36; + dst_shape.w = IntegralDivideRoundUp(new_width, 4) * + IntegralDivideRoundUp(new_height, 4); + dst_shape.c = src_shape.c; + return dst_shape; + }}; desc->uniform_buffers = { {"constant uniforms& U", [input_id, output_id, attr](const std::map& buffers) { const auto& src_shape = buffers.find(input_id)->second; const auto& dst_shape = buffers.find(output_id)->second; + int new_width = src_shape.w + attr.padding.prepended.w + + attr.padding.appended.w - 2; + int new_height = src_shape.h + attr.padding.prepended.h + + attr.padding.appended.h - 2; + int tiles_x = IntegralDivideRoundUp(new_width, 4); + int tiles_y = IntegralDivideRoundUp(new_height, 4); std::vector sizes = { src_shape.w, src_shape.h, @@ -265,18 +275,23 @@ std::vector Winograd4x4To36( 0, -attr.padding.prepended.w, -attr.padding.prepended.h, - 0, - 0, + tiles_x, + tiles_y, }; return GetByteBuffer(sizes); }}, }; - desc->resize_function = [input_id](const std::map& buffers) { + desc->resize_function = [input_id, + attr](const std::map& buffers) { const uint3 groups_size{8, 4, 1}; const auto& src_shape = buffers.find(input_id)->second; - int grid_x = IntegralDivideRoundUp(src_shape.w, 4); - int grid_y = IntegralDivideRoundUp(src_shape.h, 4); + int new_width = + src_shape.w + attr.padding.prepended.w + attr.padding.appended.w - 2; + int new_height = + src_shape.h + attr.padding.prepended.h + attr.padding.appended.h - 2; + int grid_x = IntegralDivideRoundUp(new_width, 4); + int grid_y = IntegralDivideRoundUp(new_height, 4); int grid_z = IntegralDivideRoundUp(src_shape.c, 4); int groups_x = IntegralDivideRoundUp(grid_x, groups_size.x); int groups_y = IntegralDivideRoundUp(grid_y, groups_size.y); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h index 88267694f07..26c18538fd9 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index 4c6bb140a96..797a2c4e4c9 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -291,21 +291,17 @@ class Delegate { tensor->delegate = &delegate_; } + std::string device_name = std::string([[metal_device_ name] UTF8String]); + DeviceInfo device_info(device_name); size_t storage_type_size; RuntimeOptions runtime_options; if (options_.allow_precision_loss) { storage_type_size = sizeof(HalfBits); runtime_options.storage_precision = RuntimeOptions::Precision::FP16; - const auto gpu_type = GetGpuType(); - const bool powervr = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8 || - gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10; - if (powervr) { - // PowerVR gpus support only round to zero for floating-point operations, - // to increase precision we will use F32 accumulator in this case - runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32; - } else { - // Apple own gpus support round to nearest and have better precision + if (device_info.IsRoundToNearestSupported()) { runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16; + } else { + runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32; } } else { storage_type_size = sizeof(float); @@ -395,7 +391,7 @@ class Delegate { // TODO(impjdi): Merge these. CompiledModel compiled_model; - RETURN_IF_ERROR(Compile(graph, runtime_options, &compiled_model)); + RETURN_IF_ERROR(Compile(graph, device_info, runtime_options, &compiled_model)); CompiledModel optimized_model; RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model)); diff --git a/tensorflow/lite/delegates/xnnpack/README.md b/tensorflow/lite/delegates/xnnpack/README.md index 5c2a8569fee..76cad421a09 100644 --- a/tensorflow/lite/delegates/xnnpack/README.md +++ b/tensorflow/lite/delegates/xnnpack/README.md @@ -14,7 +14,6 @@ the model to the XNNPACK delegate. The users must destroy the delegate with `TfLiteXNNPackDelegateDelete` **after** releasing the TensorFlow Lite interpreter. The snippet below illustrates the typical usage: - ```c++ // Build the interpreter std::unique_ptr interpreter; @@ -40,7 +39,7 @@ interpreter->Invoke() ... -// IMPORTANT: release the interpreter before destroing the delegate +// IMPORTANT: release the interpreter before destroying the delegate interpreter.reset(); TfLiteXNNPackDelegateDelete(xnnpack_delegate); ``` @@ -63,6 +62,15 @@ Below is the list of current operators and limitations: * Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and output are not supported. +### `AVERAGE_POOL_2D` + +* Inputs and outputs must be in 32-bit floating-point format. +* 1x1 pooling is not supported. +* Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, + but fused `TANH` and `SIGN_BIT` activations are not. +* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and + output are not supported. + ### `CONV_2D` * Inputs and outputs must be in 32-bit floating-point format. @@ -95,6 +103,15 @@ Below is the list of current operators and limitations: * Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and output are not supported. +### `MAX_POOL_2D` + +* Inputs and outputs must be in 32-bit floating-point format. +* 1x1 pooling is not supported. +* Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, + but fused `TANH` and `SIGN_BIT` activations are not. +* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and + output are not supported. + ### `MUL` * Inputs and outputs must be in 32-bit floating-point format. diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mul_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mul_test.cc index 3fd68c30586..db7eef46150 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mul_test.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mul_test.cc @@ -26,8 +26,9 @@ class MulOpModel : public SingleOpModelWithHexagon { input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, - CreateMulOptions(builder_, activation_type).Union()); + SetBuiltinOp( + BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, ActivationFunctionType_NONE).Union()); BuildInterpreter({GetShape(input1_), GetShape(input2_)}); } diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc index b1ecde764ad..508f6657e61 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc @@ -67,6 +67,8 @@ bool CheckOpVersion(const TfLiteRegistration* registration) { case kTfLiteBuiltinDepthwiseConv2d: case kTfLiteBuiltinSoftmax: return registration->version <= 2; + case kTfLiteBuiltinRelu: + return registration->version >= 2; default: return registration->version == 1; } diff --git a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py index 3d5ebf4946f..9736719c997 100644 --- a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py +++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py @@ -23,8 +23,6 @@ from __future__ import print_function import itertools from tensorflow.lite.python.op_hint import OpHint -from tensorflow.python.keras import activations -from tensorflow.python.keras import initializers from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -80,7 +78,9 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn") self._num_units = num_units if activation: - self._activation = activations.get(activation) + if activation != "tanh": + raise ValueError("activation other than tanh is not supported") + self._activation = math_ops.tanh else: self._activation = math_ops.tanh @@ -150,7 +150,7 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): def get_config(self): config = { "num_units": self._num_units, - "activation": activations.serialize(self._activation), + "activation": "tanh", "reuse": self._reuse, } base_config = super(TfLiteRNNCell, self).get_config() @@ -268,7 +268,12 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple - self._activation = activation or math_ops.tanh + if activation: + if activation != "tanh": + raise ValueError("activation other than tanh is not supported") + self._activation = math_ops.tanh + else: + self._activation = math_ops.tanh self._output_size = num_proj if num_proj else num_units self._state_size = ( @@ -516,14 +521,13 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): "num_units": self._num_units, "use_peepholes": self._use_peepholes, "cell_clip": self._cell_clip, - "initializer": initializers.serialize(self._initializer), "num_proj": self._num_proj, "proj_clip": self._proj_clip, "num_unit_shards": self._num_unit_shards, "num_proj_shards": self._num_proj_shards, "forget_bias": self._forget_bias, "state_is_tuple": self._state_is_tuple, - "activation": activations.serialize(self._activation), + "activation": "tanh", "reuse": self._reuse, } base_config = super(TFLiteLSTMCell, self).get_config() diff --git a/tensorflow/lite/experimental/ruy/WORKSPACE b/tensorflow/lite/experimental/ruy/WORKSPACE new file mode 100644 index 00000000000..8364d8047b1 --- /dev/null +++ b/tensorflow/lite/experimental/ruy/WORKSPACE @@ -0,0 +1,17 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Workspace file for the Ruy project. + +workspace(name = "com_google_ruy") diff --git a/tensorflow/lite/experimental/ruy/BUILD b/tensorflow/lite/experimental/ruy/ruy/BUILD similarity index 90% rename from tensorflow/lite/experimental/ruy/BUILD rename to tensorflow/lite/experimental/ruy/ruy/BUILD index 410b197f11f..c808c3ec063 100644 --- a/tensorflow/lite/experimental/ruy/BUILD +++ b/tensorflow/lite/experimental/ruy/ruy/BUILD @@ -131,7 +131,7 @@ cc_library( ":opt_set", ":platform", ":time", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -209,7 +209,7 @@ cc_library( ":path", ":side_pair", ":size_util", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -383,7 +383,7 @@ cc_library( ":size_util", ":spec", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -405,7 +405,7 @@ cc_library( ":path", ":platform", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -421,7 +421,7 @@ cc_library( ":kernel_common", ":opt_set", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -436,7 +436,7 @@ cc_library( ":opt_set", ":pack_common", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -456,7 +456,7 @@ cc_library( ":kernel_common", ":opt_set", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -473,7 +473,7 @@ cc_library( ":pack_common", ":path", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -509,7 +509,7 @@ cc_library( ":kernel_common", ":opt_set", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -526,7 +526,7 @@ cc_library( ":pack_common", ":path", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -566,7 +566,7 @@ cc_library( ":kernel_common", ":opt_set", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -583,7 +583,7 @@ cc_library( ":pack_common", ":path", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -623,7 +623,7 @@ cc_library( ":kernel_common", ":opt_set", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -640,7 +640,7 @@ cc_library( ":pack_common", ":path", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -685,7 +685,7 @@ cc_library( ":size_util", ":spec", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -711,7 +711,7 @@ cc_library( ":path", ":platform", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -796,7 +796,7 @@ cc_library( ":trace", ":trmul_params", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -829,7 +829,7 @@ cc_library( ":trmul", ":trmul_params", ":tune", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -877,7 +877,7 @@ cc_library( ":time", "@com_google_googletest//:gtest", ":platform", - "//tensorflow/lite/experimental/ruy/profiler:profiler", + "//tensorflow/lite/experimental/ruy/ruy/profiler:profiler", ] + ruy_test_ext_deps(), ) @@ -894,8 +894,8 @@ ruy_benchmark( ("i8", "i8", "i32", "i32"), ], deps = [ - "//tensorflow/lite/experimental/ruy:test_lib", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy:test_lib", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ], ) @@ -915,7 +915,7 @@ ruy_test( ("i8", "u8", "i32", "i32"), ], deps = [ - "//tensorflow/lite/experimental/ruy:test_lib", + "//tensorflow/lite/experimental/ruy/ruy:test_lib", "@com_google_googletest//:gtest_main", ], ) @@ -933,7 +933,7 @@ ruy_test( ], tags = ["slow"], deps = [ - "//tensorflow/lite/experimental/ruy:test_lib", + "//tensorflow/lite/experimental/ruy/ruy:test_lib", "@com_google_googletest//:gtest_main", ], ) @@ -948,7 +948,7 @@ ruy_test( ("u8", "u8", "i32", "i16"), ], deps = [ - "//tensorflow/lite/experimental/ruy:test_lib", + "//tensorflow/lite/experimental/ruy/ruy:test_lib", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/experimental/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/ruy/allocator.cc similarity index 95% rename from tensorflow/lite/experimental/ruy/allocator.cc rename to tensorflow/lite/experimental/ruy/ruy/allocator.cc index d702f70e9fb..2c507561f2f 100644 --- a/tensorflow/lite/experimental/ruy/allocator.cc +++ b/tensorflow/lite/experimental/ruy/ruy/allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/allocator.h" +#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" #include #include diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/ruy/allocator.h similarity index 95% rename from tensorflow/lite/experimental/ruy/allocator.h rename to tensorflow/lite/experimental/ruy/ruy/allocator.h index 2f5c98d6870..56aa0eef8f9 100644 --- a/tensorflow/lite/experimental/ruy/allocator.h +++ b/tensorflow/lite/experimental/ruy/ruy/allocator.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ #include #include #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" namespace ruy { @@ -182,4 +182,4 @@ class Allocator { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ diff --git a/tensorflow/lite/experimental/ruy/allocator_test.cc b/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/allocator_test.cc rename to tensorflow/lite/experimental/ruy/ruy/allocator_test.cc index 4bc99568163..1584b86b4cc 100644 --- a/tensorflow/lite/experimental/ruy/allocator_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/allocator.h" +#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" #include diff --git a/tensorflow/lite/experimental/ruy/benchmark.cc b/tensorflow/lite/experimental/ruy/ruy/benchmark.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/benchmark.cc rename to tensorflow/lite/experimental/ruy/ruy/benchmark.cc index beb52cbdab7..406345cec06 100644 --- a/tensorflow/lite/experimental/ruy/benchmark.cc +++ b/tensorflow/lite/experimental/ruy/ruy/benchmark.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/test.h" +#include "tensorflow/lite/experimental/ruy/ruy/test.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/block_map.cc b/tensorflow/lite/experimental/ruy/ruy/block_map.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/block_map.cc rename to tensorflow/lite/experimental/ruy/ruy/block_map.cc index a08fbceb941..32781d82ad3 100644 --- a/tensorflow/lite/experimental/ruy/block_map.cc +++ b/tensorflow/lite/experimental/ruy/ruy/block_map.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/block_map.h" +#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include #endif -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/block_map.h b/tensorflow/lite/experimental/ruy/ruy/block_map.h similarity index 96% rename from tensorflow/lite/experimental/ruy/block_map.h rename to tensorflow/lite/experimental/ruy/ruy/block_map.h index 48110c8bcfc..0fa4c9d5d60 100644 --- a/tensorflow/lite/experimental/ruy/block_map.h +++ b/tensorflow/lite/experimental/ruy/ruy/block_map.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" namespace ruy { @@ -158,4 +158,4 @@ inline int NumBlocks(const BlockMap& block_map) { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ diff --git a/tensorflow/lite/experimental/ruy/block_map_test.cc b/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/block_map_test.cc rename to tensorflow/lite/experimental/ruy/ruy/block_map_test.cc index fd322ab66ca..cdd7ee0e01f 100644 --- a/tensorflow/lite/experimental/ruy/block_map_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/block_map.h" +#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" #include #include @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" namespace ruy { namespace { diff --git a/tensorflow/lite/experimental/ruy/blocking_counter.cc b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc similarity index 88% rename from tensorflow/lite/experimental/ruy/blocking_counter.cc rename to tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc index eba4ae4a2f4..d313ffce51b 100644 --- a/tensorflow/lite/experimental/ruy/blocking_counter.cc +++ b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/blocking_counter.h" +#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/wait.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/wait.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/blocking_counter.h b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h similarity index 91% rename from tensorflow/lite/experimental/ruy/blocking_counter.h rename to tensorflow/lite/experimental/ruy/ruy/blocking_counter.h index e8c76d514a5..878f0e7219e 100644 --- a/tensorflow/lite/experimental/ruy/blocking_counter.h +++ b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ #include #include // NOLINT(build/c++11) // IWYU pragma: keep @@ -59,4 +59,4 @@ class BlockingCounter { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/tensorflow/lite/experimental/ruy/build_defs.bzl b/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl similarity index 100% rename from tensorflow/lite/experimental/ruy/build_defs.bzl rename to tensorflow/lite/experimental/ruy/ruy/build_defs.bzl diff --git a/tensorflow/lite/experimental/ruy/check_macros.h b/tensorflow/lite/experimental/ruy/ruy/check_macros.h similarity index 96% rename from tensorflow/lite/experimental/ruy/check_macros.h rename to tensorflow/lite/experimental/ruy/ruy/check_macros.h index 564440b4c8f..773f37d99f2 100644 --- a/tensorflow/lite/experimental/ruy/check_macros.h +++ b/tensorflow/lite/experimental/ruy/ruy/check_macros.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ #include #include @@ -135,4 +135,4 @@ inline void Failure(const char* file, int line, const char* macro, } // end namespace check_macros } // end namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ diff --git a/tensorflow/lite/experimental/ruy/check_macros_test.cc b/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/check_macros_test.cc rename to tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc index de56c1d8b8a..1a2a5a238f2 100644 --- a/tensorflow/lite/experimental/ruy/check_macros_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" #include diff --git a/tensorflow/lite/experimental/ruy/common.h b/tensorflow/lite/experimental/ruy/ruy/common.h similarity index 80% rename from tensorflow/lite/experimental/ruy/common.h rename to tensorflow/lite/experimental/ruy/ruy/common.h index 9c4e50a033a..e52a6ba6976 100644 --- a/tensorflow/lite/experimental/ruy/common.h +++ b/tensorflow/lite/experimental/ruy/ruy/common.h @@ -15,17 +15,17 @@ limitations under the License. // Miscellaneous helpers internal library. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" #if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) #define RUY_PREFETCH_LOAD(X) X @@ -70,4 +70,4 @@ Scalar SymmetricZeroPoint() { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/context.cc b/tensorflow/lite/experimental/ruy/ruy/context.cc similarity index 91% rename from tensorflow/lite/experimental/ruy/context.cc rename to tensorflow/lite/experimental/ruy/ruy/context.cc index e3cae69019d..e0d4701645f 100644 --- a/tensorflow/lite/experimental/ruy/context.cc +++ b/tensorflow/lite/experimental/ruy/ruy/context.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/detect_x86.h" -#include "tensorflow/lite/experimental/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" +#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" +#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/context.h b/tensorflow/lite/experimental/ruy/ruy/context.h similarity index 86% rename from tensorflow/lite/experimental/ruy/context.h rename to tensorflow/lite/experimental/ruy/ruy/context.h index fa8d3b7e727..a2d05a9ba5c 100644 --- a/tensorflow/lite/experimental/ruy/context.h +++ b/tensorflow/lite/experimental/ruy/ruy/context.h @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ #include #include #include -#include "tensorflow/lite/experimental/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/prepacked_cache.h" -#include "tensorflow/lite/experimental/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" +#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" +#include "tensorflow/lite/experimental/ruy/ruy/trace.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -106,4 +106,4 @@ struct Context final { } // end namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/ruy/context_test.cc b/tensorflow/lite/experimental/ruy/ruy/context_test.cc similarity index 92% rename from tensorflow/lite/experimental/ruy/context_test.cc rename to tensorflow/lite/experimental/ruy/ruy/context_test.cc index 97d8d52dc67..bddbfcf8c55 100644 --- a/tensorflow/lite/experimental/ruy/context_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/context_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" #include -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { namespace { diff --git a/tensorflow/lite/experimental/ruy/cpu_cache_size.h b/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h similarity index 89% rename from tensorflow/lite/experimental/ruy/cpu_cache_size.h rename to tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h index 16379cccfaa..95ed35ec097 100644 --- a/tensorflow/lite/experimental/ruy/cpu_cache_size.h +++ b/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { @@ -78,4 +78,4 @@ inline int SharedDataCacheSize(Path path) { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ diff --git a/tensorflow/lite/experimental/ruy/detect_arm.cc b/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/detect_arm.cc rename to tensorflow/lite/experimental/ruy/ruy/detect_arm.cc index 5940458d82a..8f6d2c9f9fe 100644 --- a/tensorflow/lite/experimental/ruy/detect_arm.cc +++ b/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc @@ -40,7 +40,7 @@ limitations under the License. * file - in actual code for (A) and in a comment for (B). */ -#include "tensorflow/lite/experimental/ruy/detect_arm.h" +#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" #if defined __linux__ && defined __aarch64__ #include diff --git a/tensorflow/lite/experimental/ruy/detect_arm.h b/tensorflow/lite/experimental/ruy/ruy/detect_arm.h similarity index 83% rename from tensorflow/lite/experimental/ruy/detect_arm.h rename to tensorflow/lite/experimental/ruy/ruy/detect_arm.h index e843a684396..9a1542d3cce 100644 --- a/tensorflow/lite/experimental/ruy/detect_arm.h +++ b/tensorflow/lite/experimental/ruy/ruy/detect_arm.h @@ -15,8 +15,8 @@ limitations under the License. // Temporary dotprod-detection code until we can rely on getauxval. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ namespace ruy { @@ -26,4 +26,4 @@ bool DetectDotprod(); } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/detect_x86.cc b/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/detect_x86.cc rename to tensorflow/lite/experimental/ruy/ruy/detect_x86.cc index 3a4c1addaec..113a73c09e3 100644 --- a/tensorflow/lite/experimental/ruy/detect_x86.cc +++ b/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/detect_x86.h" +#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" #include diff --git a/tensorflow/lite/experimental/ruy/detect_x86.h b/tensorflow/lite/experimental/ruy/ruy/detect_x86.h similarity index 86% rename from tensorflow/lite/experimental/ruy/detect_x86.h rename to tensorflow/lite/experimental/ruy/ruy/detect_x86.h index 0b761de6841..185dabe06a5 100644 --- a/tensorflow/lite/experimental/ruy/detect_x86.h +++ b/tensorflow/lite/experimental/ruy/ruy/detect_x86.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { @@ -46,4 +46,4 @@ inline bool DetectCpuAvxVnni() { return false; } } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/dispatch.h b/tensorflow/lite/experimental/ruy/ruy/dispatch.h similarity index 94% rename from tensorflow/lite/experimental/ruy/dispatch.h rename to tensorflow/lite/experimental/ruy/ruy/dispatch.h index 7938769d3e7..d1e97e29b9c 100644 --- a/tensorflow/lite/experimental/ruy/dispatch.h +++ b/tensorflow/lite/experimental/ruy/ruy/dispatch.h @@ -30,31 +30,31 @@ limitations under the License. // // This file also performs some checking of invariants to catch user errors. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ #include #include #include // IWYU pragma: keep #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/trmul_params.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" namespace ruy { @@ -479,4 +479,4 @@ void DispatchMul(const Matrix& lhs, const Matrix& rhs, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/ruy/example.cc b/tensorflow/lite/experimental/ruy/ruy/example.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/example.cc rename to tensorflow/lite/experimental/ruy/ruy/example.cc index d53672a3a00..5d31d6c2e3e 100644 --- a/tensorflow/lite/experimental/ruy/example.cc +++ b/tensorflow/lite/experimental/ruy/ruy/example.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" void ExampleMulFloat(ruy::Context *context) { const float lhs_data[] = {1, 2, 3, 4}; diff --git a/tensorflow/lite/experimental/ruy/example_advanced.cc b/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/example_advanced.cc rename to tensorflow/lite/experimental/ruy/ruy/example_advanced.cc index f4415e1cb4b..9e1dd17f86d 100644 --- a/tensorflow/lite/experimental/ruy/example_advanced.cc +++ b/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy_advanced.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" // Simple allocator for allocating pre-packed matrices. class SimpleAllocator { diff --git a/tensorflow/lite/experimental/ruy/have_built_path_for.h b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h similarity index 76% rename from tensorflow/lite/experimental/ruy/have_built_path_for.h rename to tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h index 7ca0f4d1c40..08651facb7e 100644 --- a/tensorflow/lite/experimental/ruy/have_built_path_for.h +++ b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { @@ -29,4 +29,4 @@ bool HaveBuiltPathForAvxVnni(); } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/tensorflow/lite/experimental/ruy/have_built_path_for_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc similarity index 89% rename from tensorflow/lite/experimental/ruy/have_built_path_for_avx2.cc rename to tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc index be694cea228..a9bcfbbbcfb 100644 --- a/tensorflow/lite/experimental/ruy/have_built_path_for_avx2.cc +++ b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/have_built_path_for_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc similarity index 89% rename from tensorflow/lite/experimental/ruy/have_built_path_for_avx512.cc rename to tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc index ccfea773b15..2b42cba26c9 100644 --- a/tensorflow/lite/experimental/ruy/have_built_path_for_avx512.cc +++ b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/have_built_path_for_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc similarity index 91% rename from tensorflow/lite/experimental/ruy/have_built_path_for_avxvnni.cc rename to tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc index e2318e67792..42f9cb668df 100644 --- a/tensorflow/lite/experimental/ruy/have_built_path_for_avxvnni.cc +++ b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/have_built_path_for_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc similarity index 91% rename from tensorflow/lite/experimental/ruy/have_built_path_for_sse42.cc rename to tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc index 1be687f6bd7..e7470f54520 100644 --- a/tensorflow/lite/experimental/ruy/have_built_path_for_sse42.cc +++ b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/internal_matrix.h b/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h similarity index 96% rename from tensorflow/lite/experimental/ruy/internal_matrix.h rename to tensorflow/lite/experimental/ruy/ruy/internal_matrix.h index 597af4757fd..cf10adf084d 100644 --- a/tensorflow/lite/experimental/ruy/internal_matrix.h +++ b/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h @@ -87,18 +87,18 @@ limitations under the License. // exists is so that PMatrix is not exposed to users -- we prefer to keep the // internal matrix types hidden, even from "advanced" users. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ #include #include #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" namespace ruy { @@ -385,4 +385,4 @@ KernelLayout ToKernelLayout() { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/kernel.h b/tensorflow/lite/experimental/ruy/ruy/kernel.h similarity index 65% rename from tensorflow/lite/experimental/ruy/kernel.h rename to tensorflow/lite/experimental/ruy/ruy/kernel.h index fd470efc5de..dd9a60b8d09 100644 --- a/tensorflow/lite/experimental/ruy/kernel.h +++ b/tensorflow/lite/experimental/ruy/ruy/kernel.h @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" // IWYU pragma: begin_exports #if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/kernel_arm.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_arm.h" #elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/kernel_x86.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_x86.h" #else -#include "tensorflow/lite/experimental/ruy/kernel_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" #endif // IWYU pragma: end_exports -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ diff --git a/tensorflow/lite/experimental/ruy/kernel_arm.h b/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h similarity index 89% rename from tensorflow/lite/experimental/ruy/kernel_arm.h rename to tensorflow/lite/experimental/ruy/ruy/kernel_arm.h index 9493c059eb5..760f0f0b4b5 100644 --- a/tensorflow/lite/experimental/ruy/kernel_arm.h +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h @@ -13,24 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ #include #include -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -208,4 +208,4 @@ struct Kernel -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/kernel_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/kernel_avx2.cc rename to tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc index 783e52b2aee..1113469fd28 100644 --- a/tensorflow/lite/experimental/ruy/kernel_avx2.cc +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/kernel_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/kernel_avx512.cc rename to tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc index 4fe75ad3fdf..e51876fcc02 100644 --- a/tensorflow/lite/experimental/ruy/kernel_avx512.cc +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/kernel_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/kernel_avxvnni.cc rename to tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc index 60fcd8ed652..c868c00957b 100644 --- a/tensorflow/lite/experimental/ruy/kernel_avxvnni.cc +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/kernel_common.h b/tensorflow/lite/experimental/ruy/ruy/kernel_common.h similarity index 95% rename from tensorflow/lite/experimental/ruy/kernel_common.h rename to tensorflow/lite/experimental/ruy/ruy/kernel_common.h index 179a72b8460..c1721b81869 100644 --- a/tensorflow/lite/experimental/ruy/kernel_common.h +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_common.h @@ -13,25 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ #include #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -478,4 +478,4 @@ struct KernelParamsFloat {}; } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/kernel_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/kernel_sse42.cc rename to tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc index c312cb3f641..46a6d045e6a 100644 --- a/tensorflow/lite/experimental/ruy/kernel_sse42.cc +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/kernel_x86.h b/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h similarity index 92% rename from tensorflow/lite/experimental/ruy/kernel_x86.h rename to tensorflow/lite/experimental/ruy/ruy/kernel_x86.h index 51a684e077b..f79f70ab88c 100644 --- a/tensorflow/lite/experimental/ruy/kernel_x86.h +++ b/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ #include -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -219,4 +219,4 @@ struct Kernel> { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/matrix.h b/tensorflow/lite/experimental/ruy/ruy/matrix.h similarity index 96% rename from tensorflow/lite/experimental/ruy/matrix.h rename to tensorflow/lite/experimental/ruy/ruy/matrix.h index 978714c353e..a76f32136c6 100644 --- a/tensorflow/lite/experimental/ruy/matrix.h +++ b/tensorflow/lite/experimental/ruy/ruy/matrix.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ #include #include // IWYU pragma: keep #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" namespace ruy { @@ -179,4 +179,4 @@ constexpr int FixedKernelLayout::kRows; } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/opt_set.h b/tensorflow/lite/experimental/ruy/ruy/opt_set.h similarity index 90% rename from tensorflow/lite/experimental/ruy/opt_set.h rename to tensorflow/lite/experimental/ruy/ruy/opt_set.h index d082adece9c..fef0107ed01 100644 --- a/tensorflow/lite/experimental/ruy/opt_set.h +++ b/tensorflow/lite/experimental/ruy/ruy/opt_set.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ // RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling // certain optimizations. It should be used by defining that macro on the @@ -48,4 +48,4 @@ limitations under the License. #define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ diff --git a/tensorflow/lite/experimental/ruy/pack.h b/tensorflow/lite/experimental/ruy/ruy/pack.h similarity index 91% rename from tensorflow/lite/experimental/ruy/pack.h rename to tensorflow/lite/experimental/ruy/ruy/pack.h index 61008c23605..96040aa1039 100644 --- a/tensorflow/lite/experimental/ruy/pack.h +++ b/tensorflow/lite/experimental/ruy/ruy/pack.h @@ -80,19 +80,19 @@ limitations under the License. // column sums for quantization (and never row sums, since the LHS is // transposed). -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" // IWYU pragma: begin_exports #if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/pack_arm.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_arm.h" #elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/pack_x86.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_x86.h" #else -#include "tensorflow/lite/experimental/ruy/pack_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" #endif // IWYU pragma: end_exports -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ diff --git a/tensorflow/lite/experimental/ruy/pack_arm.cc b/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/pack_arm.cc rename to tensorflow/lite/experimental/ruy/ruy/pack_arm.cc index ec30d0b3b65..52b55a57cc6 100644 --- a/tensorflow/lite/experimental/ruy/pack_arm.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/pack_arm.h b/tensorflow/lite/experimental/ruy/ruy/pack_arm.h similarity index 96% rename from tensorflow/lite/experimental/ruy/pack_arm.h rename to tensorflow/lite/experimental/ruy/ruy/pack_arm.h index e2c538a6140..f4691d66fcb 100644 --- a/tensorflow/lite/experimental/ruy/pack_arm.h +++ b/tensorflow/lite/experimental/ruy/ruy/pack_arm.h @@ -80,22 +80,22 @@ limitations under the License. // column sums for quantization (and never row sums, since the LHS is // transposed). -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -494,4 +494,4 @@ struct PackImpl, float, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/pack_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/pack_avx2.cc rename to tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc index 417eee6ae46..3575943e50e 100644 --- a/tensorflow/lite/experimental/ruy/pack_avx2.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/pack_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/pack_avx512.cc rename to tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc index beaaf5cddfa..d5636572eed 100644 --- a/tensorflow/lite/experimental/ruy/pack_avx512.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/pack_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/pack_avxvnni.cc rename to tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc index fc892327d73..49b4a1f978c 100644 --- a/tensorflow/lite/experimental/ruy/pack_avxvnni.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/pack_common.h b/tensorflow/lite/experimental/ruy/ruy/pack_common.h similarity index 92% rename from tensorflow/lite/experimental/ruy/pack_common.h rename to tensorflow/lite/experimental/ruy/ruy/pack_common.h index 74960dfbd50..91d47af8a5f 100644 --- a/tensorflow/lite/experimental/ruy/pack_common.h +++ b/tensorflow/lite/experimental/ruy/ruy/pack_common.h @@ -80,20 +80,20 @@ limitations under the License. // column sums for quantization (and never row sums, since the LHS is // transposed). -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -243,4 +243,4 @@ void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/pack_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/pack_sse42.cc rename to tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc index 9be7b8d0bc1..ecd1cf83c6d 100644 --- a/tensorflow/lite/experimental/ruy/pack_sse42.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) #include // IWYU pragma: keep diff --git a/tensorflow/lite/experimental/ruy/pack_x86.h b/tensorflow/lite/experimental/ruy/ruy/pack_x86.h similarity index 96% rename from tensorflow/lite/experimental/ruy/pack_x86.h rename to tensorflow/lite/experimental/ruy/ruy/pack_x86.h index 7ac27141ca2..8bdc88e5763 100644 --- a/tensorflow/lite/experimental/ruy/pack_x86.h +++ b/tensorflow/lite/experimental/ruy/ruy/pack_x86.h @@ -80,23 +80,23 @@ limitations under the License. // column sums for quantization (and never row sums, since the LHS is // transposed). -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ #include #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -458,4 +458,4 @@ struct PackImpl, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/path.h b/tensorflow/lite/experimental/ruy/ruy/path.h similarity index 95% rename from tensorflow/lite/experimental/ruy/path.h rename to tensorflow/lite/experimental/ruy/ruy/path.h index d0c7095dbef..5973b8040a7 100644 --- a/tensorflow/lite/experimental/ruy/path.h +++ b/tensorflow/lite/experimental/ruy/ruy/path.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ #include -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" namespace ruy { @@ -159,4 +159,4 @@ constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ diff --git a/tensorflow/lite/experimental/ruy/platform.h b/tensorflow/lite/experimental/ruy/ruy/platform.h similarity index 96% rename from tensorflow/lite/experimental/ruy/platform.h rename to tensorflow/lite/experimental/ruy/ruy/platform.h index d86c9576e5c..d6e86e6a792 100644 --- a/tensorflow/lite/experimental/ruy/platform.h +++ b/tensorflow/lite/experimental/ruy/ruy/platform.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ #ifdef __ANDROID_NDK__ #include @@ -153,4 +153,4 @@ limitations under the License. #define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 #endif -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ diff --git a/tensorflow/lite/experimental/ruy/pmu.cc b/tensorflow/lite/experimental/ruy/ruy/pmu.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/pmu.cc rename to tensorflow/lite/experimental/ruy/ruy/pmu.cc index 86c137bbf6a..6405aa15e6a 100644 --- a/tensorflow/lite/experimental/ruy/pmu.cc +++ b/tensorflow/lite/experimental/ruy/ruy/pmu.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/pmu.h" +#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" -#include "tensorflow/lite/experimental/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" #ifdef __linux__ #include diff --git a/tensorflow/lite/experimental/ruy/pmu.h b/tensorflow/lite/experimental/ruy/ruy/pmu.h similarity index 87% rename from tensorflow/lite/experimental/ruy/pmu.h rename to tensorflow/lite/experimental/ruy/ruy/pmu.h index 03f0cb7d878..721c1d5f1cc 100644 --- a/tensorflow/lite/experimental/ruy/pmu.h +++ b/tensorflow/lite/experimental/ruy/ruy/pmu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ namespace ruy { @@ -41,4 +41,4 @@ class PmuEvents { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ diff --git a/tensorflow/lite/experimental/ruy/prepack.h b/tensorflow/lite/experimental/ruy/ruy/prepack.h similarity index 80% rename from tensorflow/lite/experimental/ruy/prepack.h rename to tensorflow/lite/experimental/ruy/ruy/prepack.h index 0f2b6c4d2b4..794b8df7b4d 100644 --- a/tensorflow/lite/experimental/ruy/prepack.h +++ b/tensorflow/lite/experimental/ruy/ruy/prepack.h @@ -15,24 +15,24 @@ limitations under the License. // Implementation of low-level pre-packing API. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/trmul_params.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -105,4 +105,4 @@ void MulWithPrepackedInternal(const Matrix& lhs, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc similarity index 87% rename from tensorflow/lite/experimental/ruy/prepacked_cache.cc rename to tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc index c3d0405d583..da683020169 100644 --- a/tensorflow/lite/experimental/ruy/prepacked_cache.cc +++ b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/prepacked_cache.h" +#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include + +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" namespace ruy { @@ -30,7 +32,9 @@ CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { const TimePoint time = CacheNow(); itr->second.second = time; } - return itr; + // std::move() is required in the MSVC STL when NDEBUG is not set, and has no + // effect in libc++. + return std::move(itr); // NOLINT } void PrepackedCache::Insert(const CacheKey &key, diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h similarity index 92% rename from tensorflow/lite/experimental/ruy/prepacked_cache.h rename to tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h index 5d38ddbbb0a..f2ee15559c7 100644 --- a/tensorflow/lite/experimental/ruy/prepacked_cache.h +++ b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ #include #include @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/time.h" namespace ruy { @@ -127,4 +127,4 @@ class PrepackedCache { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/prepacked_cache_test.cc rename to tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc index fc5d70f25a4..453190a3b88 100644 --- a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/prepacked_cache.h" +#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" #include // NOLINT(build/c++11) #include -#include "tensorflow/lite/experimental/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/time.h" namespace ruy { namespace { diff --git a/tensorflow/lite/experimental/ruy/profiler/BUILD b/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD similarity index 100% rename from tensorflow/lite/experimental/ruy/profiler/BUILD rename to tensorflow/lite/experimental/ruy/ruy/profiler/BUILD diff --git a/tensorflow/lite/experimental/ruy/profiler/README.md b/tensorflow/lite/experimental/ruy/ruy/profiler/README.md similarity index 98% rename from tensorflow/lite/experimental/ruy/profiler/README.md rename to tensorflow/lite/experimental/ruy/ruy/profiler/README.md index 28cc55020e5..8d7902566b3 100644 --- a/tensorflow/lite/experimental/ruy/profiler/README.md +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/README.md @@ -133,7 +133,7 @@ But also the following advantages: The philosophy underlying this profiler is that software performance depends on software engineers profiling often, and a key factor limiting that in practice is the difficulty or cumbersome aspects of profiling with more serious profilers -such as Linux's "perf", espectially in embedded/mobile development: multiple +such as Linux's "perf", especially in embedded/mobile development: multiple command lines are involved to copy symbol files to devices, retrieve profile data from the device, etc. In that context, it is useful to make profiling as easy as benchmarking, even on embedded targets, even if the price to pay for diff --git a/tensorflow/lite/experimental/ruy/profiler/instrumentation.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/profiler/instrumentation.cc rename to tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc index bad6a22d3b3..b7c330c04bd 100644 --- a/tensorflow/lite/experimental/ruy/profiler/instrumentation.cc +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #ifdef RUY_PROFILER diff --git a/tensorflow/lite/experimental/ruy/profiler/instrumentation.h b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h similarity index 96% rename from tensorflow/lite/experimental/ruy/profiler/instrumentation.h rename to tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h index cb0e70297d7..a9046d465af 100644 --- a/tensorflow/lite/experimental/ruy/profiler/instrumentation.h +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ #ifdef RUY_PROFILER #include @@ -200,4 +200,4 @@ class ScopeLabel { } // namespace profiler } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/tensorflow/lite/experimental/ruy/profiler/profiler.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc similarity index 93% rename from tensorflow/lite/experimental/ruy/profiler/profiler.cc rename to tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc index d192ba36f3a..c5ff598ee2b 100644 --- a/tensorflow/lite/experimental/ruy/profiler/profiler.cc +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/profiler/profiler.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" #ifdef RUY_PROFILER #include @@ -24,8 +24,8 @@ limitations under the License. #include #endif -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/profiler/treeview.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" namespace ruy { namespace profiler { diff --git a/tensorflow/lite/experimental/ruy/profiler/profiler.h b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h similarity index 89% rename from tensorflow/lite/experimental/ruy/profiler/profiler.h rename to tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h index 7166c910d97..19ef0deba0c 100644 --- a/tensorflow/lite/experimental/ruy/profiler/profiler.h +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ #include @@ -25,8 +25,8 @@ limitations under the License. #include #endif -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/profiler/treeview.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" namespace ruy { namespace profiler { @@ -103,4 +103,4 @@ struct ScopeProfile { } // namespace profiler } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ diff --git a/tensorflow/lite/experimental/ruy/profiler/test.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc similarity index 96% rename from tensorflow/lite/experimental/ruy/profiler/test.cc rename to tensorflow/lite/experimental/ruy/ruy/profiler/test.cc index 9e4f1734920..feab967c87c 100644 --- a/tensorflow/lite/experimental/ruy/profiler/test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/profiler/profiler.h" -#include "tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.h" -#include "tensorflow/lite/experimental/ruy/profiler/treeview.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" namespace ruy { namespace profiler { diff --git a/tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc similarity index 96% rename from tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.cc rename to tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc index 822563c814d..e9b5929c9b7 100644 --- a/tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.cc +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" namespace { diff --git a/tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.h b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h similarity index 68% rename from tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.h rename to tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h index 1272f5b1c21..d6a80a09042 100644 --- a/tensorflow/lite/experimental/ruy/profiler/test_instrumented_library.h +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" void MergeSort(int size, int* data); -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/tensorflow/lite/experimental/ruy/profiler/treeview.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/profiler/treeview.cc rename to tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc index 8bf969ee33d..256d2a1106c 100644 --- a/tensorflow/lite/experimental/ruy/profiler/treeview.cc +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc @@ -15,7 +15,7 @@ limitations under the License. #ifdef RUY_PROFILER -#include "tensorflow/lite/experimental/ruy/profiler/treeview.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" #include #include diff --git a/tensorflow/lite/experimental/ruy/profiler/treeview.h b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h similarity index 94% rename from tensorflow/lite/experimental/ruy/profiler/treeview.h rename to tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h index b833e7b08c4..7f48af5ece0 100644 --- a/tensorflow/lite/experimental/ruy/profiler/treeview.h +++ b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ #ifdef RUY_PROFILER @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" namespace ruy { namespace profiler { @@ -127,4 +127,4 @@ void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, #endif // RUY_PROFILER -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy.h b/tensorflow/lite/experimental/ruy/ruy/ruy.h similarity index 74% rename from tensorflow/lite/experimental/ruy/ruy.h rename to tensorflow/lite/experimental/ruy/ruy/ruy.h index 436b1af94a1..783c410cf82 100644 --- a/tensorflow/lite/experimental/ruy/ruy.h +++ b/tensorflow/lite/experimental/ruy/ruy/ruy.h @@ -15,14 +15,14 @@ limitations under the License. // This is the only Ruy header that users should #include. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" namespace ruy { @@ -39,4 +39,4 @@ void Mul(const Matrix& lhs, const Matrix& rhs, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy_advanced.h b/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h similarity index 84% rename from tensorflow/lite/experimental/ruy/ruy_advanced.h rename to tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h index 68748198f3e..0b24636ef06 100644 --- a/tensorflow/lite/experimental/ruy/ruy_advanced.h +++ b/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ #include #include -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/prepack.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/prepack.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" namespace ruy { @@ -66,4 +66,4 @@ void MulWithPrepacked(const Matrix& lhs, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy_test.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl similarity index 100% rename from tensorflow/lite/experimental/ruy/ruy_test.bzl rename to tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl diff --git a/tensorflow/lite/experimental/ruy/ruy_test_ext.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl similarity index 100% rename from tensorflow/lite/experimental/ruy/ruy_test_ext.bzl rename to tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl diff --git a/tensorflow/lite/experimental/ruy/side_pair.h b/tensorflow/lite/experimental/ruy/ruy/side_pair.h similarity index 88% rename from tensorflow/lite/experimental/ruy/side_pair.h rename to tensorflow/lite/experimental/ruy/ruy/side_pair.h index 56ac16c85d1..a3210e27a53 100644 --- a/tensorflow/lite/experimental/ruy/side_pair.h +++ b/tensorflow/lite/experimental/ruy/ruy/side_pair.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ -#include "tensorflow/lite/experimental/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" namespace ruy { @@ -61,4 +61,4 @@ class SidePair final { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ diff --git a/tensorflow/lite/experimental/ruy/size_util.h b/tensorflow/lite/experimental/ruy/ruy/size_util.h similarity index 90% rename from tensorflow/lite/experimental/ruy/size_util.h rename to tensorflow/lite/experimental/ruy/ruy/size_util.h index 5cfde0d48d7..56dd095de85 100644 --- a/tensorflow/lite/experimental/ruy/size_util.h +++ b/tensorflow/lite/experimental/ruy/ruy/size_util.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" #ifdef _WIN32 #include @@ -90,4 +90,4 @@ Integer round_up_pot(Integer value, Modulo modulo) { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ diff --git a/tensorflow/lite/experimental/ruy/size_util_test.cc b/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/size_util_test.cc rename to tensorflow/lite/experimental/ruy/ruy/size_util_test.cc index 48605dcd5ad..442c31958cc 100644 --- a/tensorflow/lite/experimental/ruy/size_util_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" #include #include diff --git a/tensorflow/lite/experimental/ruy/spec.h b/tensorflow/lite/experimental/ruy/ruy/spec.h similarity index 95% rename from tensorflow/lite/experimental/ruy/spec.h rename to tensorflow/lite/experimental/ruy/ruy/spec.h index 3f856e301ca..584d90ea047 100644 --- a/tensorflow/lite/experimental/ruy/spec.h +++ b/tensorflow/lite/experimental/ruy/ruy/spec.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ #include #include -#include "tensorflow/lite/experimental/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" namespace ruy { @@ -115,4 +115,4 @@ struct BasicSpec { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ diff --git a/tensorflow/lite/experimental/ruy/test.h b/tensorflow/lite/experimental/ruy/ruy/test.h similarity index 99% rename from tensorflow/lite/experimental/ruy/test.h rename to tensorflow/lite/experimental/ruy/ruy/test.h index a7b2ff483b2..305b5a844fa 100644 --- a/tensorflow/lite/experimental/ruy/test.h +++ b/tensorflow/lite/experimental/ruy/ruy/test.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ #include @@ -37,13 +37,13 @@ limitations under the License. #include #include // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/matrix.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/pmu.h" -#include "tensorflow/lite/experimental/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/ruy_advanced.h" -#include "tensorflow/lite/experimental/ruy/spec.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" // IWYU pragma: export +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" // IWYU pragma: export +#include "tensorflow/lite/experimental/ruy/ruy/time.h" #ifdef RUY_TEST_EXTERNAL_PATHS #define EIGEN_USE_THREADS @@ -55,7 +55,7 @@ limitations under the License. #endif #ifdef RUY_PROFILER -#include "tensorflow/lite/experimental/ruy/profiler/profiler.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" #endif namespace ruy { @@ -2122,4 +2122,4 @@ void TestLinearAllOrders(int rows, int depth, int cols) { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ diff --git a/tensorflow/lite/experimental/ruy/test_fast.cc b/tensorflow/lite/experimental/ruy/ruy/test_fast.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/test_fast.cc rename to tensorflow/lite/experimental/ruy/ruy/test_fast.cc index 58d69da1524..6b7026530ac 100644 --- a/tensorflow/lite/experimental/ruy/test_fast.cc +++ b/tensorflow/lite/experimental/ruy/ruy/test_fast.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/test.h" +#include "tensorflow/lite/experimental/ruy/ruy/test.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/test_slow.cc b/tensorflow/lite/experimental/ruy/ruy/test_slow.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/test_slow.cc rename to tensorflow/lite/experimental/ruy/ruy/test_slow.cc index 4faa628b67f..7e7292cd503 100644 --- a/tensorflow/lite/experimental/ruy/test_slow.cc +++ b/tensorflow/lite/experimental/ruy/ruy/test_slow.cc @@ -15,7 +15,7 @@ limitations under the License. // This test contains more expensive test cases. -#include "tensorflow/lite/experimental/ruy/test.h" +#include "tensorflow/lite/experimental/ruy/ruy/test.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/test_special_specs.cc b/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc similarity index 99% rename from tensorflow/lite/experimental/ruy/test_special_specs.cc rename to tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc index bcdd5da8e59..6f5a88c833a 100644 --- a/tensorflow/lite/experimental/ruy/test_special_specs.cc +++ b/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc @@ -15,7 +15,7 @@ limitations under the License. // This test covers non-basic specs. -#include "tensorflow/lite/experimental/ruy/test.h" +#include "tensorflow/lite/experimental/ruy/ruy/test.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/thread_pool.cc rename to tensorflow/lite/experimental/ruy/ruy/thread_pool.cc index 0e7130f8734..eb86a1fbf38 100644 --- a/tensorflow/lite/experimental/ruy/thread_pool.cc +++ b/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/thread_pool.h" +#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" #include #include // NOLINT(build/c++11) @@ -24,8 +24,8 @@ limitations under the License. #include // NOLINT(build/c++11) #include // NOLINT(build/c++11) -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/wait.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/wait.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/thread_pool.h b/tensorflow/lite/experimental/ruy/ruy/thread_pool.h similarity index 93% rename from tensorflow/lite/experimental/ruy/thread_pool.h rename to tensorflow/lite/experimental/ruy/ruy/thread_pool.h index 179f5d41f43..5504bd80614 100644 --- a/tensorflow/lite/experimental/ruy/thread_pool.h +++ b/tensorflow/lite/experimental/ruy/ruy/thread_pool.h @@ -16,12 +16,12 @@ limitations under the License. // This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 // license. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ #include -#include "tensorflow/lite/experimental/ruy/blocking_counter.h" +#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" namespace ruy { @@ -99,4 +99,4 @@ class ThreadPool { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ diff --git a/tensorflow/lite/experimental/ruy/time.h b/tensorflow/lite/experimental/ruy/ruy/time.h similarity index 93% rename from tensorflow/lite/experimental/ruy/time.h rename to tensorflow/lite/experimental/ruy/ruy/time.h index d96ed3409e0..9dba75eb4c5 100644 --- a/tensorflow/lite/experimental/ruy/time.h +++ b/tensorflow/lite/experimental/ruy/ruy/time.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ #include // NOLINT(build/c++11) #include // IWYU pragma: keep @@ -78,4 +78,4 @@ inline TimePoint CoarseNow() { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ diff --git a/tensorflow/lite/experimental/ruy/trace.cc b/tensorflow/lite/experimental/ruy/ruy/trace.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/trace.cc rename to tensorflow/lite/experimental/ruy/ruy/trace.cc index 83b31103d42..806f6ec2cf2 100644 --- a/tensorflow/lite/experimental/ruy/trace.cc +++ b/tensorflow/lite/experimental/ruy/ruy/trace.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/trace.h" +#include "tensorflow/lite/experimental/ruy/ruy/trace.h" #include #include // IWYU pragma: keep @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/time.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/trace.h b/tensorflow/lite/experimental/ruy/ruy/trace.h similarity index 89% rename from tensorflow/lite/experimental/ruy/trace.h rename to tensorflow/lite/experimental/ruy/ruy/trace.h index 87be2e9b675..6680438c124 100644 --- a/tensorflow/lite/experimental/ruy/trace.h +++ b/tensorflow/lite/experimental/ruy/ruy/trace.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ #include -#include "tensorflow/lite/experimental/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" namespace ruy { @@ -70,4 +70,4 @@ inline void TraceRecordEnd(Trace*) {} } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ diff --git a/tensorflow/lite/experimental/ruy/trmul.cc b/tensorflow/lite/experimental/ruy/ruy/trmul.cc similarity index 94% rename from tensorflow/lite/experimental/ruy/trmul.cc rename to tensorflow/lite/experimental/ruy/ruy/trmul.cc index 783d8d08b9f..c3e15a9d628 100644 --- a/tensorflow/lite/experimental/ruy/trmul.cc +++ b/tensorflow/lite/experimental/ruy/ruy/trmul.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/trmul.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" #include #include @@ -21,20 +21,20 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" +#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" +#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" +#include "tensorflow/lite/experimental/ruy/ruy/common.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" +#include "tensorflow/lite/experimental/ruy/ruy/spec.h" +#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" +#include "tensorflow/lite/experimental/ruy/ruy/trace.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { diff --git a/tensorflow/lite/experimental/ruy/trmul.h b/tensorflow/lite/experimental/ruy/ruy/trmul.h similarity index 82% rename from tensorflow/lite/experimental/ruy/trmul.h rename to tensorflow/lite/experimental/ruy/ruy/trmul.h index 6f7d7ba4590..9786b7f6180 100644 --- a/tensorflow/lite/experimental/ruy/trmul.h +++ b/tensorflow/lite/experimental/ruy/ruy/trmul.h @@ -23,11 +23,11 @@ limitations under the License. // That is why TrMul is nicer to implement, allowing for a more symmetric // treatment of LHS and RHS. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/trmul_params.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" namespace ruy { @@ -35,4 +35,4 @@ void TrMul(TrMulParams* params, Context* context); } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ diff --git a/tensorflow/lite/experimental/ruy/trmul_params.h b/tensorflow/lite/experimental/ruy/ruy/trmul_params.h similarity index 84% rename from tensorflow/lite/experimental/ruy/trmul_params.h rename to tensorflow/lite/experimental/ruy/ruy/trmul_params.h index 060dd9c6c18..c694f16b938 100644 --- a/tensorflow/lite/experimental/ruy/trmul_params.h +++ b/tensorflow/lite/experimental/ruy/ruy/trmul_params.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ -#include "tensorflow/lite/experimental/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" +#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" namespace ruy { @@ -64,4 +64,4 @@ struct TrMulParams { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ diff --git a/tensorflow/lite/experimental/ruy/tune.cc b/tensorflow/lite/experimental/ruy/ruy/tune.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/tune.cc rename to tensorflow/lite/experimental/ruy/ruy/tune.cc index 3249b5b211c..63fa0338d6d 100644 --- a/tensorflow/lite/experimental/ruy/tune.cc +++ b/tensorflow/lite/experimental/ruy/ruy/tune.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" #include #include diff --git a/tensorflow/lite/experimental/ruy/tune.h b/tensorflow/lite/experimental/ruy/ruy/tune.h similarity index 95% rename from tensorflow/lite/experimental/ruy/tune.h rename to tensorflow/lite/experimental/ruy/ruy/tune.h index be38ca3fab0..3471604e37a 100644 --- a/tensorflow/lite/experimental/ruy/tune.h +++ b/tensorflow/lite/experimental/ruy/ruy/tune.h @@ -69,12 +69,12 @@ limitations under the License. // nano-benchmark. // * Maybe using getcpu in conjunction with the nano-benchmark to cache // per-CPU-id nano-benchmark results. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ -#include "tensorflow/lite/experimental/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/time.h" // Tuning only implemented on NEON_64 at the moment (see assembly code // in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect @@ -160,4 +160,4 @@ class TuningResolver { } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ diff --git a/tensorflow/lite/experimental/ruy/tune_test.cc b/tensorflow/lite/experimental/ruy/ruy/tune_test.cc similarity index 96% rename from tensorflow/lite/experimental/ruy/tune_test.cc rename to tensorflow/lite/experimental/ruy/ruy/tune_test.cc index 051c34910b6..0b00e645195 100644 --- a/tensorflow/lite/experimental/ruy/tune_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/tune_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" #include // NOLINT(build/c++11) #include // NOLINT(build/c++11) diff --git a/tensorflow/lite/experimental/ruy/tune_tool.cc b/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc similarity index 96% rename from tensorflow/lite/experimental/ruy/tune_tool.cc rename to tensorflow/lite/experimental/ruy/ruy/tune_tool.cc index bda0a0af93d..04cfa6d6b89 100644 --- a/tensorflow/lite/experimental/ruy/tune_tool.cc +++ b/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include // NOLINT(build/c++11) -#include "tensorflow/lite/experimental/ruy/tune.h" +#include "tensorflow/lite/experimental/ruy/ruy/tune.h" #ifdef _WIN32 #define getpid() 0 diff --git a/tensorflow/lite/experimental/ruy/wait.cc b/tensorflow/lite/experimental/ruy/ruy/wait.cc similarity index 98% rename from tensorflow/lite/experimental/ruy/wait.cc rename to tensorflow/lite/experimental/ruy/ruy/wait.cc index 04a5848fb44..7d91b6ebce6 100644 --- a/tensorflow/lite/experimental/ruy/wait.cc +++ b/tensorflow/lite/experimental/ruy/ruy/wait.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/wait.h" +#include "tensorflow/lite/experimental/ruy/ruy/wait.h" #include // NOLINT(build/c++11) diff --git a/tensorflow/lite/experimental/ruy/wait.h b/tensorflow/lite/experimental/ruy/ruy/wait.h similarity index 93% rename from tensorflow/lite/experimental/ruy/wait.h rename to tensorflow/lite/experimental/ruy/ruy/wait.h index 0d06a4c2748..a3cd26282af 100644 --- a/tensorflow/lite/experimental/ruy/wait.h +++ b/tensorflow/lite/experimental/ruy/ruy/wait.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ #include // NOLINT(build/c++11) #include #include // NOLINT(build/c++11) -#include "tensorflow/lite/experimental/ruy/time.h" +#include "tensorflow/lite/experimental/ruy/ruy/time.h" namespace ruy { @@ -70,4 +70,4 @@ void Wait(const std::function& condition, } // namespace ruy -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ diff --git a/tensorflow/lite/experimental/ruy/wait_test.cc b/tensorflow/lite/experimental/ruy/ruy/wait_test.cc similarity index 97% rename from tensorflow/lite/experimental/ruy/wait_test.cc rename to tensorflow/lite/experimental/ruy/ruy/wait_test.cc index 71e9d1d5b35..b1b7558583d 100644 --- a/tensorflow/lite/experimental/ruy/wait_test.cc +++ b/tensorflow/lite/experimental/ruy/ruy/wait_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/ruy/wait.h" +#include "tensorflow/lite/experimental/ruy/ruy/wait.h" #include #include // NOLINT(build/c++11) @@ -21,7 +21,7 @@ limitations under the License. #include // NOLINT(build/c++11) #include -#include "tensorflow/lite/experimental/ruy/platform.h" +#include "tensorflow/lite/experimental/ruy/ruy/platform.h" namespace ruy { namespace { diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java index 3bd60edfef6..054ea0e9730 100644 --- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java @@ -25,6 +25,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.tensorflow.lite.DataType; import org.tensorflow.lite.Tensor.QuantizationParams; import org.tensorflow.lite.schema.Tensor; +import org.tensorflow.lite.support.metadata.schema.ModelMetadata; import org.tensorflow.lite.support.metadata.schema.TensorMetadata; /** @@ -96,6 +97,11 @@ public class MetadataExtractor { zipFile = createZipFile(buffer); } + /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */ + public Boolean hasMetadata() { + return metadataInfo != null; + } + /** * Gets the packed associated file with the specified {@code fileName}. * @@ -154,6 +160,16 @@ public class MetadataExtractor { return modelInfo.getInputTensorType(inputIndex); } + /** + * Gets the root handler for the model metadata. + * + * @throws IllegalStateException if this model does not contain model metadata + */ + public ModelMetadata getModelMetadata() { + assertMetadataInfo(); + return metadataInfo.getModelMetadata(); + } + /** Gets the count of output tensors in the model. */ public int getOutputTensorCount() { return modelInfo.getOutputTensorCount(); diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java index 6a6419393d5..ad13a3050af 100644 --- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java @@ -29,6 +29,9 @@ import org.tensorflow.lite.support.metadata.schema.TensorMetadata; /** Extracts model metadata information out of TFLite metadata FlatBuffer. */ final class ModelMetadataInfo { + /** The root handler for the model metadata. */ + private final ModelMetadata modelMetadata; + /** Metadata array of input tensors. */ private final List inputsMetadata; @@ -45,7 +48,7 @@ final class ModelMetadataInfo { ModelMetadataInfo(ByteBuffer buffer) { checkNotNull(buffer, "Metadata flatbuffer cannot be null."); - ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer); + modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer); checkArgument( modelMetadata.subgraphMetadataLength() > 0, "The metadata flatbuffer does not contain any subgraph metadata."); @@ -73,6 +76,11 @@ final class ModelMetadataInfo { return inputsMetadata.get(inputIndex); } + /** Gets the root handler for the model metadata. */ + ModelMetadata getModelMetadata() { + return modelMetadata; + } + /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */ int getOutputTensorCount() { return outputsMetadata.size(); diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/experimental/writer/BUILD index 34d5c68c490..2792cc7ae2f 100644 --- a/tensorflow/lite/experimental/writer/BUILD +++ b/tensorflow/lite/experimental/writer/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs_with_reflection", ], ) @@ -63,7 +64,9 @@ cc_test( deps = [ ":writer_lib", "//tensorflow/lite:framework", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc index eb5b5cb1088..a88544e95b1 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/experimental/writer/writer_lib.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include #include +#include #include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/writer/enum_mapping.h" @@ -50,7 +52,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) { std::vector operator_to_opcode; // TODO(aselle): Augment this once we put execution plan in schema. operator_to_opcode.resize(subgraph_->nodes_size(), -1); - for (int op_index : subgraph_->execution_plan()) { + for (int op_index : execution_plan_) { const auto* node_and_registration = subgraph_->node_and_registration(op_index); const TfLiteRegistration* registration = &node_and_registration->second; @@ -63,7 +65,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) { } } // second pass serialize operators - for (int op_index : subgraph_->execution_plan()) { + for (int op_index : execution_plan_) { const auto* node_and_registration = subgraph_->node_and_registration(op_index); const TfLiteNode& node = node_and_registration->first; @@ -255,10 +257,8 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr* out, std::vector> subgraphs_as_vector; { // subgraph specific stuff auto tensors = ExportTensors(&builder); - std::vector written_inputs = - RemapTensorIndicesToWritten(subgraph_->inputs()); - std::vector written_outputs = - RemapTensorIndicesToWritten(subgraph_->outputs()); + std::vector written_inputs = RemapTensorIndicesToWritten(inputs_); + std::vector written_outputs = RemapTensorIndicesToWritten(outputs_); auto inputs = ExportVector(&builder, written_inputs); auto outputs = ExportVector(&builder, written_outputs); @@ -309,4 +309,63 @@ TfLiteStatus SubgraphWriter::RegisterCustomWriter( return kTfLiteOk; } +TfLiteStatus SubgraphWriter::CheckInputOutput( + const std::vector& inputs, const std::vector& outputs, + const std::vector& execution_plan) { + std::unordered_set known_tensors(inputs.begin(), inputs.end()); + // Scan execution plan and confirm input tensors are known before each node + // executes. Then append output tensors to known tensors. + for (int op_index : execution_plan) { + const auto* node_and_registration = + subgraph_->node_and_registration(op_index); + const TfLiteNode& node = node_and_registration->first; + for (int tensor_index : TfLiteIntArrayView(node.inputs)) { + if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) { + // Skip constant tensors. + if (tensor->allocation_type == kTfLiteMmapRo) { + continue; + } + } + + if (known_tensors.find(tensor_index) == known_tensors.end()) { + subgraph_->context()->ReportError( + subgraph_->context(), + "Node (%d) uses an input (%d) that is not provided.", op_index, + tensor_index); + return kTfLiteError; + } + } + TfLiteIntArrayView outputs(node.outputs); + known_tensors.insert(outputs.begin(), outputs.end()); + } + + // Check if outputs are known tensors or constants. + for (int tensor_index : outputs) { + if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) { + // Skip constant tensors. + if (tensor->allocation_type == kTfLiteMmapRo) { + continue; + } + } + + if (known_tensors.find(tensor_index) == known_tensors.end()) { + subgraph_->context()->ReportError( + subgraph_->context(), + "Output (%d) is not produced by the execution plan.", tensor_index); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus SubgraphWriter::SetCustomInputOutput( + const std::vector& inputs, const std::vector& outputs, + const std::vector& execution_plan) { + TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan)); + inputs_ = inputs; + outputs_ = outputs; + execution_plan_ = execution_plan; + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/lite/experimental/writer/writer_lib.h b/tensorflow/lite/experimental/writer/writer_lib.h index cc2b8fcf174..f7816dcc33e 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.h +++ b/tensorflow/lite/experimental/writer/writer_lib.h @@ -30,6 +30,7 @@ limitations under the License. #include #include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/writer/enum_mapping.h" @@ -47,9 +48,13 @@ class SubgraphWriter { flatbuffers::Offset>* output_options, CustomOptionsFormat* custom_options_format); - // Construct an subgraph writer for the specified `subgraph`. Then, - // a uses .Write() or .GetBuffer(...) to extract the data. - explicit SubgraphWriter(Subgraph* subgraph) : subgraph_(subgraph) { + // Construct a subgraph writer for the specified `subgraph`. Then, use + // .Write() or .GetBuffer(...) to extract the data. + explicit SubgraphWriter(Subgraph* subgraph) + : subgraph_(subgraph), + inputs_(subgraph->inputs()), + outputs_(subgraph->outputs()), + execution_plan_(subgraph->execution_plan()) { buffers_.push_back(std::make_pair(nullptr, 0)); } @@ -65,6 +70,11 @@ class SubgraphWriter { void SetUnusedTensors(const std::set& unused_tensors) { unused_tensors_ = unused_tensors; } + // Sets custom inputs, outputs, and execution_plan so that a portion of the + // subgraph is written to the buffer instead of the whole subgraph. + TfLiteStatus SetCustomInputOutput(const std::vector& inputs, + const std::vector& outputs, + const std::vector& execution_plan); private: template @@ -84,6 +94,12 @@ class SubgraphWriter { template std::vector RemapTensorIndicesToWritten(const T& input); + // Checks if given `input`, `output`, and `execution_plan` represents a valid + // model within the Subgraph. + TfLiteStatus CheckInputOutput(const std::vector& inputs, + const std::vector& outputs, + const std::vector& execution_plan); + int GetOpCodeForBuiltin(int builtin_op_index) { // auto it = builtin_op_to_opcode_.find(builtin_op_index); std::pair result = @@ -107,6 +123,12 @@ class SubgraphWriter { // The subgraph we are writing Subgraph* subgraph_; + // Input tensor indices to be written. + std::vector inputs_; + // Output tensor indices to be written. + std::vector outputs_; + // Order of nodes to be written. + std::vector execution_plan_; // Keep track of byte buffers std::vector> buffers_; // List of op codes and mappings from builtin or custom op to opcode diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/experimental/writer/writer_lib_test.cc index 0b5520f9bee..41cca88ead7 100644 --- a/tensorflow/lite/experimental/writer/writer_lib_test.cc +++ b/tensorflow/lite/experimental/writer/writer_lib_test.cc @@ -14,10 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/writer/writer_lib.h" + #include +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/testing/util.h" namespace tflite { @@ -55,6 +58,102 @@ TEST(Writer, FloatModelTest) { CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); } +// Tests writing only a portion of the subgraph. +TEST(Writer, CustomInputOutputTest) { + Interpreter interpreter; + interpreter.AddTensors(4); + constexpr float kFoo[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, + TfLiteQuantization()); + interpreter.SetTensorParametersReadOnly( + 1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(), + reinterpret_cast(kFoo), sizeof(kFoo)); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, + TfLiteQuantization()); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3}, + TfLiteQuantization()); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + + // Add two ops: Add and Relu + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, + reinterpret_cast(builtin_data), reg); + + const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2); + + // Only write the second op. + SubgraphWriter writer(&interpreter.primary_subgraph()); + EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3}, + /*execution_plan=*/{1}), + kTfLiteOk); + writer.SetUnusedTensors({0, 1}); + writer.Write("/tmp/test_custom.tflite"); + + std::unique_ptr model = + FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite"); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr new_interpreter; + builder(&new_interpreter); + ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); +} + +TEST(Writer, CustomInputOutputErrorCasesTest) { + Interpreter interpreter; + interpreter.AddTensors(5); + constexpr float kFoo[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, + TfLiteQuantization()); + interpreter.SetTensorParametersReadOnly( + 1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(), + reinterpret_cast(kFoo), sizeof(kFoo)); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, + TfLiteQuantization()); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3}, + TfLiteQuantization()); + interpreter.SetTensorParametersReadWrite(4, kTfLiteFloat32, "e", {3}, + TfLiteQuantization()); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({4}); + + // Add three ops. + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, + reinterpret_cast(builtin_data), reg); + + const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2); + + const TfLiteRegistration* reg3 = resolver.FindOp(BuiltinOperator_RELU6, 1); + interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, reg3); + + SubgraphWriter writer(&interpreter.primary_subgraph()); + + // Test wrong input. + EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3}, + /*execution_plan=*/{0, 1}), + kTfLiteError); + // Test wrong output. + EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{4}, + /*execution_plan=*/{0, 1}), + kTfLiteError); + // Test a valid case. + EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{3}, + /*execution_plan=*/{0, 1}), + kTfLiteOk); +} + TEST(Writer, PerTensorQuantizedModelTest) { Interpreter interpreter; interpreter.AddTensors(3); diff --git a/tensorflow/lite/g3doc/convert/python_api.md b/tensorflow/lite/g3doc/convert/python_api.md index ef9bdf2c9ef..ba86eac25fd 100644 --- a/tensorflow/lite/g3doc/convert/python_api.md +++ b/tensorflow/lite/g3doc/convert/python_api.md @@ -171,7 +171,7 @@ TensorFlow Lite metadata provides a standard for model descriptions. The metadata is an important source of knowledge about what the model does and its input / output information. This makes it easier for other developers to understand the best practices and for code generators to create platform -specific wrapper code. For more infomation, please refer to the +specific wrapper code. For more information, please refer to the [TensorFlow Lite Metadata](metadata.md) section. ## Installing TensorFlow @@ -192,7 +192,7 @@ either install the nightly build with [Docker](https://www.tensorflow.org/install/docker), or [build the pip package from source](https://www.tensorflow.org/install/source). -### Custom ops in the experimenal new converter +### Custom ops in the experimental new converter There is a behavior change in how models containing [custom ops](https://www.tensorflow.org/lite/guide/ops_custom) (those for which diff --git a/tensorflow/lite/g3doc/guide/index.md b/tensorflow/lite/g3doc/guide/index.md index bb658237bba..6419b3b55de 100644 --- a/tensorflow/lite/g3doc/guide/index.md +++ b/tensorflow/lite/g3doc/guide/index.md @@ -28,9 +28,10 @@ improve: TensorFlow Lite works with a huge range of devices, from tiny microcontrollers to powerful mobile phones. -Key Point: The TensorFlow Lite binary is smaller than 300KB when all supported -operators are linked, and less than 200KB when using only the operators needed -for supporting the common image classification models InceptionV3 and MobileNet. +Key Point: The TensorFlow Lite binary is ~1MB when all 125+ supported operators +are linked (for 32-bit ARM builds), and less than 300KB when using only the +operators needed for supporting the common image classification models +InceptionV3 and MobileNet. ## Get started diff --git a/tensorflow/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md index 56093e63722..32f5ef485aa 100644 --- a/tensorflow/lite/g3doc/performance/best_practices.md +++ b/tensorflow/lite/g3doc/performance/best_practices.md @@ -52,7 +52,7 @@ operator is executed. Check out our Model optimization aims to create smaller models that are generally faster and more energy efficient, so that they can be deployed on mobile devices. There are -multiple optimization techniques suppored by TensorFlow Lite, such as +multiple optimization techniques supported by TensorFlow Lite, such as quantization. Check out our [model optimization docs](model_optimization.md) for details. diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index ef8f5a8773a..5d7807cd291 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -97,23 +97,25 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src, const char* kEmptyTensorName = ""; -// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but -// we avoid the absl dependency for binary size reasons. -#ifdef __has_attribute -#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x) -#else -#define TFLITE_HAS_ATTRIBUTE(x) 0 -#endif +#if TFLITE_HAS_ATTRIBUTE_WEAK +// Using weak symbols to create a delegate allows automatic injection of the +// delegate simply by adding it as a dependency. -#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__)) -// Using weak symbols for the flex delegate allows automatic injection of the -// delegate simply by adding it as a dependency. See also the strong override in +// For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. -__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { +TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { + return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +} + +// For XNNPACK delegate, see also the strong override in +// lite/enable_xnnpack_delegate.cc. +TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireXNNPACKDelegate( + int num_threads) { return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } #else Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; +Interpreter::TfLiteDelegatePtr (*AcquireXNNPACKDelegate)(int) = nullptr; #endif namespace impl { @@ -415,6 +417,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors( return kEmptyTensorName; }; + num_fp32_tensors_ = 0; for (int i = 0; i < tensors->size(); ++i) { const auto* tensor = tensors->Get(i); std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); @@ -425,6 +428,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; continue; } + if (type == kTfLiteFloat32) { + ++num_fp32_tensors_; + } auto get_readonly_data = [&](const char** buffer_data, size_t* buffer_size) { // TODO(aselle): Check what happens if we have an unspecified size @@ -507,12 +513,23 @@ TfLiteStatus InterpreterBuilder::ParseTensors( return status; } -TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) { - // Apply Flex delegate if applicable. - if (!has_flex_op_ || AcquireFlexDelegate == nullptr) { - return kTfLiteOk; - } else if (auto flex_delegate = AcquireFlexDelegate()) { - return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); +TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter, + int num_threads) { + // First, apply XNNPACK delegate if applicable. + if (AcquireXNNPACKDelegate && num_fp32_tensors_ > 0) { + if (auto xnnpack_delegate = AcquireXNNPACKDelegate(num_threads)) { + // The execution will fall back to default implementation if the XNNPACK + // delegate fails to be applied. Therefore, we ignore the return status + // here and let it fall through the rest of the code. + interpreter->ModifyGraphWithDelegate(std::move(xnnpack_delegate)); + } + } + + // Secondly, apply Flex delegate if applicable. + if (has_flex_op_ && AcquireFlexDelegate) { + if (auto flex_delegate = AcquireFlexDelegate()) { + return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); + } } return kTfLiteOk; @@ -625,7 +642,7 @@ TfLiteStatus InterpreterBuilder::operator()( modified_subgraph->SetVariables(std::move(variables)); } - if (ApplyDelegates(interpreter->get()) != kTfLiteOk) + if (ApplyDelegates(interpreter->get(), num_threads) != kTfLiteOk) return cleanup_and_error(); return kTfLiteOk; diff --git a/tensorflow/lite/interpreter_builder.h b/tensorflow/lite/interpreter_builder.h index 1d150d6f1d4..1b8ae5a8e68 100644 --- a/tensorflow/lite/interpreter_builder.h +++ b/tensorflow/lite/interpreter_builder.h @@ -78,7 +78,7 @@ class InterpreterBuilder { const flatbuffers::Vector>* buffers, const flatbuffers::Vector>* tensors, Subgraph* subgraph); - TfLiteStatus ApplyDelegates(Interpreter* interpreter); + TfLiteStatus ApplyDelegates(Interpreter* interpreter, int num_threads); TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, TfLiteQuantization* quantization, const std::vector& dims); @@ -95,6 +95,7 @@ class InterpreterBuilder { const Allocation* allocation_ = nullptr; bool has_flex_op_ = false; + int num_fp32_tensors_ = 0; }; } // namespace impl diff --git a/tensorflow/lite/java/demo/app/build.gradle b/tensorflow/lite/java/demo/app/build.gradle index fca18430fa5..41bbf38fedb 100644 --- a/tensorflow/lite/java/demo/app/build.gradle +++ b/tensorflow/lite/java/demo/app/build.gradle @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:support-v13:25.2.0' // Build off of nightly TensorFlow Lite - implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' - implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly' + implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } // Use local TensorFlow library // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' } diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 872d3c0822b..28eefb2895f 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -274,7 +274,7 @@ cc_library( # For now this unconditionally depends on both ruy and gemmlowp. # See the comment inside class CpuBackendContext on the # gemmlowp_context_ and ruy_context_ members. - "//tensorflow/lite/experimental/ruy:context", + "//tensorflow/lite/experimental/ruy/ruy:context", "@gemmlowp", "//tensorflow/lite:external_cpu_backend_context", ], @@ -295,8 +295,8 @@ cc_library( # We only need to depend on gemmlowp when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. - "//tensorflow/lite/experimental/ruy:context", - "//tensorflow/lite/experimental/ruy:thread_pool", + "//tensorflow/lite/experimental/ruy/ruy:context", + "//tensorflow/lite/experimental/ruy/ruy:thread_pool", "@gemmlowp", ], ) @@ -334,9 +334,9 @@ cc_library( ":cpu_backend_threadpool", # Depend on ruy regardless of `tflite_with_ruy`. See the comment in # cpu_backend_gemm.h about why ruy is the generic path. - "//tensorflow/lite/experimental/ruy", - "//tensorflow/lite/experimental/ruy:path", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy", + "//tensorflow/lite/experimental/ruy/ruy:path", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", # We only need to depend on gemmlowp and Eigen when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. @@ -355,7 +355,7 @@ cc_test( "@com_google_googletest//:gtest", # ruy's reference path provides the reference implementation # that this test compares against. - "//tensorflow/lite/experimental/ruy", + "//tensorflow/lite/experimental/ruy/ruy", ], ) @@ -379,8 +379,8 @@ cc_library( copts = tflite_copts() + micro_copts(), deps = [ "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels/internal:cppmath", "//tensorflow/lite/kernels/internal:quantization_util", - "//tensorflow/lite/kernels/internal:round", "@flatbuffers", ], ) @@ -596,7 +596,7 @@ cc_library( "//tensorflow/lite:context", "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/kernels:hashtable_op_kernels", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:tensor", "//third_party/fft2d:fft2d_headers", @@ -613,7 +613,7 @@ cc_library( ":cpu_backend_context", ":op_macros", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:kernel_utils", diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index eece297ffea..9fa503a9189 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -542,7 +542,12 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - TF_LITE_ENSURE_EQ(context, input->type, output->type); + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE(context, + input->type == kTfLiteInt8 || input->type == kTfLiteUInt8); + } else { + TF_LITE_ENSURE_EQ(context, input->type, output->type); + } TF_LITE_ENSURE(context, NumDimensions(input) >= 1); @@ -923,12 +928,12 @@ TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input, return kTfLiteOk; } -template +template TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, SoftmaxOpData* data) { optimized_ops::Softmax(data->params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(output), - GetTensorData(output)); + GetTensorData(input), GetTensorShape(output), + GetTensorData(output)); return kTfLiteOk; } @@ -944,16 +949,41 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return SoftmaxFloat(context, input, output, params); } case kTfLiteUInt8: { - return SoftmaxQuantized(context, input, output, data); + switch (output->type) { + case kTfLiteUInt8: + return SoftmaxQuantized(context, input, output, + data); + case kTfLiteInt16: + return SoftmaxQuantized(context, input, output, + data); + default: + TF_LITE_KERNEL_LOG(context, + "Only uint8_t and int16_t outputs are supported " + "with uint8_t inputs currently, got %s.", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } } case kTfLiteInt8: { - return SoftmaxQuantized(context, input, output, data); + switch (output->type) { + case kTfLiteInt8: + return SoftmaxQuantized(context, input, output, data); + case kTfLiteInt16: + return SoftmaxQuantized(context, input, output, + data); + default: + TF_LITE_KERNEL_LOG(context, + "Only int8_t and int16_t outputs are supported " + "with int8_t inputs currently, got %s.", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } } default: TF_LITE_KERNEL_LOG( context, - "Only float32, uint8_t and Int8_t are supported currently, got %s.", + "Only float32, uint8_t and int8_t are supported currently, got %s.", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc index 781d3d03eb2..9f4f0b02795 100644 --- a/tensorflow/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -78,14 +78,17 @@ class BaseActivationsOpModel : public SingleOpModel { } // A dedicated constructor for SOFTMAX, which does some options. - BaseActivationsOpModel(float softmax_beta, TensorData input) { + BaseActivationsOpModel(float softmax_beta, TensorData input, + TensorType output_type) { input_ = AddInput(input); - if (input.type == TensorType_UINT8) { - output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); - } else if (input.type == TensorType_INT8) { + if (output_type == TensorType_UINT8) { + output_ = AddOutput({TensorType_UINT8, {}, 0, 0, 1. / 256}); + } else if (output_type == TensorType_INT8) { output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128}); + } else if (output_type == TensorType_INT16) { + output_ = AddOutput({TensorType_INT16, {}, 0, 0, 1. / 32768, -16384}); } else { - output_ = AddOutput({input.type, {}}); + output_ = AddOutput({output_type, {}}); } SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, CreateSoftmaxOptions(builder_, softmax_beta).Union()); @@ -919,8 +922,8 @@ TEST_P(LogisticOpTest, SigmoidInt16) { } TEST(FloatActivationsOpTest, Softmax4D) { - FloatActivationsOpModel m(0.1, - /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); + FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {1, 2, 1, 4}}, + TensorType_FLOAT32); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 @@ -932,8 +935,8 @@ TEST(FloatActivationsOpTest, Softmax4D) { }))); // Same input, but a different shape. - FloatActivationsOpModel m2(0.1, - /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}}); + FloatActivationsOpModel m2(0.1f, {TensorType_FLOAT32, {4, 1, 1, 2}}, + TensorType_FLOAT32); m2.SetInput({ 0, -6, // 2, 4, // @@ -950,9 +953,8 @@ TEST(FloatActivationsOpTest, Softmax4D) { } TEST(QuantizedActivationsOpTest, Softmax4DUint8) { - QuantizedActivationsOpModel m( - 0.1, - /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {1, 2, 1, 4}, -10, 10}, + TensorType_UINT8); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 @@ -968,8 +970,7 @@ TEST(QuantizedActivationsOpTest, Softmax4DUint8) { // Same input, but a different shape. QuantizedActivationsOpModel m2( - 0.1, - /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + 0.1f, {TensorType_UINT8, {4, 1, 1, 2}, -10, 10}, TensorType_UINT8); m2.SetInput({ 0, -6, // 2, 4, // @@ -988,11 +989,48 @@ TEST(QuantizedActivationsOpTest, Softmax4DUint8) { kQuantizedTolerance))); } +TEST(QuantizedActivationsOpTest, Softmax4DUint8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {1, 2, 1, 4}, -10, 10}, + TensorType_INT16); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1f, {TensorType_UINT8, {4, 1, 1, 2}, -10, 10}, TensorType_INT16); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + // Test quantized softmax with int8 input and output. With the same input as in // QuantizedActivationsOpTest.Softmax1D, the dequantized output is identical. TEST(QuantizedActivationsOpTest, Softmax1DInt8) { - QuantizedActivationsOpModel m(0.1, - /*input=*/{TensorType_INT8, {8}, -10, 10}); + QuantizedActivationsOpModel m(0.1, {TensorType_INT8, {8}, -10, 10}, + TensorType_INT8); m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); m.Invoke(); EXPECT_THAT( @@ -1002,11 +1040,26 @@ TEST(QuantizedActivationsOpTest, Softmax1DInt8) { kQuantizedTolerance))); } +// Test quantized softmax with int8 input and int16 output. With the same input +// as in QuantizedActivationsOpTest.Softmax1D, the dequantized output is +// identical. +TEST(QuantizedActivationsOpTest, Softmax1DInt8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {8}, -10, 10}, + TensorType_INT16); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453, + 0.13281, 0.07813, 0.26563, 0.10938}, + kQuantizedTolerance))); +} + // Test quantized softmax with int8 input and output. With the same input as in // QuantizedActivationsOpTest.Softmax2D, the dequantized output is identical. TEST(QuantizedActivationsOpTest, Softmax2DInt8) { - QuantizedActivationsOpModel m(0.1, - /*input=*/{TensorType_INT8, {2, 4}, -10, 10}); + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {2, 4}, -10, 10}, + TensorType_INT8); m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // @@ -1021,8 +1074,87 @@ TEST(QuantizedActivationsOpTest, Softmax2DInt8) { kQuantizedTolerance))); // Same input, but a different shape. - QuantizedActivationsOpModel m2(0.1, - /*input=*/{TensorType_INT8, {4, 2}, -10, 10}); + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 2}, -10, 10}, + TensorType_INT8); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +// Test quantized softmax with int8 input and int16 output. With the same input +// as in QuantizedActivationsOpTest.Softmax2D, the dequantized output is +// identical. +TEST(QuantizedActivationsOpTest, Softmax2DInt8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {2, 4}, -10, 10}, + TensorType_INT16); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 2}, -10, 10}, + TensorType_INT16); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +// Test quantized softmax with int8 input and output. With the same input as in +// QuantizedActivationsOpTest.Softmax3D, the dequantized output is identical. +TEST(QuantizedActivationsOpTest, Softmax3DInt8) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {1, 2, 4}, -10, 10}, + TensorType_INT8); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 1, 2}, -10, 10}, + TensorType_INT8); m2.SetInput({ 0, -6, // 2, 4, // @@ -1043,15 +1175,57 @@ TEST(QuantizedActivationsOpTest, Softmax2DInt8) { // Test quantized softmax with int8 input and output. With the same input as in // QuantizedActivationsOpTest.Softmax3D, the dequantized output is identical. -TEST(QuantizedActivationsOpTest, Softmax3DInt8) { - QuantizedActivationsOpModel m( - 0.1, - /*input=*/{TensorType_INT8, {1, 2, 4}, -10, 10}); +TEST(QuantizedActivationsOpTest, Softmax3DInt8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {1, 2, 4}, -10, 10}, + TensorType_INT16); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 1, 2}, -10, 10}, + TensorType_INT16); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +// Test quantized softmax with int8 input and output. With the same input as in +// QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical. +TEST(QuantizedActivationsOpTest, Softmax4DInt8) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {1, 2, 1, 4}, -10, 10}, + TensorType_INT8); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + -68, -95, -54, -38, // + -70, -93, -12, -81, // + })); EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { @@ -1061,9 +1235,8 @@ TEST(QuantizedActivationsOpTest, Softmax3DInt8) { kQuantizedTolerance))); // Same input, but a different shape. - QuantizedActivationsOpModel m2( - 0.1, - /*input=*/{TensorType_INT8, {4, 1, 2}, -10, 10}); + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 1, 1, 2}, -10, 10}, + TensorType_INT8); m2.SetInput({ 0, -6, // 2, 4, // @@ -1084,20 +1257,15 @@ TEST(QuantizedActivationsOpTest, Softmax3DInt8) { // Test quantized softmax with int8 input and output. With the same input as in // QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical. -TEST(QuantizedActivationsOpTest, Softmax4DInt8) { - QuantizedActivationsOpModel m( - 0.1, - /*input=*/{TensorType_INT8, {1, 2, 1, 4}, -10, 10}); +TEST(QuantizedActivationsOpTest, Softmax4DInt8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_INT8, {1, 2, 1, 4}, -10, 10}, + TensorType_INT16); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({ - -68, -95, -54, -38, // - -70, -93, -12, -81, // - })); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -1106,9 +1274,8 @@ TEST(QuantizedActivationsOpTest, Softmax4DInt8) { kQuantizedTolerance))); // Same input, but a different shape. - QuantizedActivationsOpModel m2( - 0.1, - /*input=*/{TensorType_INT8, {4, 1, 1, 2}, -10, 10}); + QuantizedActivationsOpModel m2(0.1f, {TensorType_INT8, {4, 1, 1, 2}, -10, 10}, + TensorType_INT16); m2.SetInput({ 0, -6, // 2, 4, // @@ -1116,7 +1283,7 @@ TEST(QuantizedActivationsOpTest, Softmax4DInt8) { 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.645656, 0.354344, // @@ -1128,8 +1295,8 @@ TEST(QuantizedActivationsOpTest, Softmax4DInt8) { } TEST(FloatActivationsOpTest, Softmax3D) { - FloatActivationsOpModel m(0.1, - /*input=*/{TensorType_FLOAT32, {1, 2, 4}}); + FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {1, 2, 4}}, + TensorType_FLOAT32); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 @@ -1141,8 +1308,8 @@ TEST(FloatActivationsOpTest, Softmax3D) { }))); // Same input, but a different shape. - FloatActivationsOpModel m2(0.1, - /*input=*/{TensorType_FLOAT32, {4, 1, 2}}); + FloatActivationsOpModel m2(0.1f, {TensorType_FLOAT32, {4, 1, 2}}, + TensorType_FLOAT32); m2.SetInput({ 0, -6, // 2, 4, // @@ -1159,9 +1326,8 @@ TEST(FloatActivationsOpTest, Softmax3D) { } TEST(QuantizedActivationsOpTest, Softmax3DUint8) { - QuantizedActivationsOpModel m( - 0.1, - /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10}); + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {1, 2, 4}, -10, 10}, + TensorType_UINT8); m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 @@ -1176,9 +1342,8 @@ TEST(QuantizedActivationsOpTest, Softmax3DUint8) { kQuantizedTolerance))); // Same input, but a different shape. - QuantizedActivationsOpModel m2( - 0.1, - /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10}); + QuantizedActivationsOpModel m2(0.1f, {TensorType_UINT8, {4, 1, 2}, -10, 10}, + TensorType_UINT8); m2.SetInput({ 0, -6, // 2, 4, // @@ -1197,9 +1362,46 @@ TEST(QuantizedActivationsOpTest, Softmax3DUint8) { kQuantizedTolerance))); } +TEST(QuantizedActivationsOpTest, Softmax3DUint8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {1, 2, 4}, -10, 10}, + TensorType_INT16); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1f, {TensorType_UINT8, {4, 1, 2}, -10, 10}, + TensorType_INT16); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax1D) { - FloatActivationsOpModel m(0.1, - /*input=*/{TensorType_FLOAT32, {8}}); + FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {8}}, + TensorType_FLOAT32); m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); m.Invoke(); EXPECT_THAT( @@ -1209,8 +1411,8 @@ TEST(FloatActivationsOpTest, Softmax1D) { } TEST(QuantizedActivationsOpTest, Softmax1DUint8) { - QuantizedActivationsOpModel m(0.1, - /*input=*/{TensorType_UINT8, {8}, -10, 10}); + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {8}, -10, 10}, + TensorType_UINT8); m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); m.Invoke(); EXPECT_THAT( @@ -1220,9 +1422,21 @@ TEST(QuantizedActivationsOpTest, Softmax1DUint8) { kQuantizedTolerance))); } +TEST(QuantizedActivationsOpTest, Softmax1DUint8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {8}, -10, 10}, + TensorType_INT16); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453, + 0.13281, 0.07813, 0.26563, 0.10938}, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax2D) { - FloatActivationsOpModel m(0.1, - /*input=*/{TensorType_FLOAT32, {2, 4}}); + FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {2, 4}}, + TensorType_FLOAT32); m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // @@ -1234,8 +1448,8 @@ TEST(FloatActivationsOpTest, Softmax2D) { }))); // Same input, but a different shape. - FloatActivationsOpModel m2(0.1, - /*input=*/{TensorType_FLOAT32, {4, 2}}); + FloatActivationsOpModel m2(0.1f, {TensorType_FLOAT32, {4, 2}}, + TensorType_FLOAT32); m2.SetInput({ 0, -6, // 2, 4, // @@ -1252,8 +1466,8 @@ TEST(FloatActivationsOpTest, Softmax2D) { } TEST(FloatActivationsOpTest, Softmax2DMultithreading) { - FloatActivationsOpModel m(0.1, - /*input=*/{TensorType_FLOAT32, {16, 4}}); + FloatActivationsOpModel m(0.1f, {TensorType_FLOAT32, {16, 4}}, + TensorType_FLOAT32); m.SetInput({ 0, -6, 2, 4, // Thread 1. 3, -2, 10, 1, // @@ -1294,8 +1508,8 @@ TEST(FloatActivationsOpTest, Softmax2DMultithreading) { }))); // Same input, but a different shape. - FloatActivationsOpModel m2(0.1, - /*input=*/{TensorType_FLOAT32, {16, 2}}); + FloatActivationsOpModel m2(0.1f, {TensorType_FLOAT32, {16, 2}}, + TensorType_FLOAT32); m2.SetInput({ 0, -6, // Thread 1 2, 4, // @@ -1337,8 +1551,8 @@ TEST(FloatActivationsOpTest, Softmax2DMultithreading) { } TEST(QuantizedActivationsOpTest, Softmax2DUint8) { - QuantizedActivationsOpModel m(0.1, - /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {2, 4}, -10, 10}, + TensorType_UINT8); m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // @@ -1353,8 +1567,8 @@ TEST(QuantizedActivationsOpTest, Softmax2DUint8) { kQuantizedTolerance))); // Same input, but a different shape. - QuantizedActivationsOpModel m2(0.1, - /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); + QuantizedActivationsOpModel m2(0.1f, {TensorType_UINT8, {4, 2}, -10, 10}, + TensorType_UINT8); m2.SetInput({ 0, -6, // 2, 4, // @@ -1373,6 +1587,43 @@ TEST(QuantizedActivationsOpTest, Softmax2DUint8) { kQuantizedTolerance))); } +TEST(QuantizedActivationsOpTest, Softmax2DUint8Int16) { + QuantizedActivationsOpModel m(0.1f, {TensorType_UINT8, {2, 4}, -10, 10}, + TensorType_INT16); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1f, {TensorType_UINT8, {4, 2}, -10, 10}, + TensorType_INT16); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + // This contains the same test values as the Softmax test, but reference answer // generated via the following snippet of python: // logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32) diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index 6ad73fabf10..dfeea5d0a64 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" #include "tensorflow/lite/kernels/op_macros.h" namespace { diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 2d3d76deaea..eafae75fc47 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" #include "tensorflow/lite/external_cpu_backend_context.h" namespace tflite { diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h index b19d5bc990b..6fde100a4bf 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h @@ -35,7 +35,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h index a73149c50fa..253c035688f 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h @@ -22,7 +22,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h index 4e1158bc0cc..c02dce2b773 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ -#include "tensorflow/lite/experimental/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/path.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc index d545b80f97f..d26df809c97 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_threadpool.h b/tensorflow/lite/kernels/cpu_backend_threadpool.h index d1e1d14c3c2..b924826a07c 100644 --- a/tensorflow/lite/kernels/cpu_backend_threadpool.h +++ b/tensorflow/lite/kernels/cpu_backend_threadpool.h @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #ifdef TFLITE_WITH_RUY -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/thread_pool.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" #else #include "public/gemmlowp.h" #endif diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index e7612e39c71..952073ef02a 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -243,13 +243,13 @@ cc_library( ":strided_slice_logic", ":types", ":reference_base", - ":round", + ":cppmath", ":tensor", ":tensor_utils", ":transpose_utils", "//third_party/eigen3", "@gemmlowp//:fixedpoint", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", @@ -294,14 +294,14 @@ cc_library( ":types", ":legacy_types", ":legacy_reference_base", - ":round", + ":cppmath", "//third_party/eigen3", "@gemmlowp", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", "//tensorflow/lite/kernels:cpu_backend_gemm", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -346,9 +346,10 @@ cc_test( ) cc_library( - name = "round", + name = "cppmath", srcs = [], - hdrs = ["round.h"], + hdrs = ["cppmath.h"], + build_for_embedded = True, copts = tflite_copts(), ) @@ -359,7 +360,7 @@ cc_library( copts = tflite_copts() + micro_copts(), deps = [ ":compatibility", - ":round", + ":cppmath", ":types", "//tensorflow/lite/kernels:op_macros", ], @@ -467,7 +468,7 @@ cc_library( ":common", ":compatibility", ":quantization_util", - ":round", + ":cppmath", ":strided_slice_logic", ":tensor", ":tensor_utils", @@ -476,7 +477,7 @@ cc_library( "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -532,7 +533,7 @@ cc_library( ":common", ":compatibility", ":quantization_util", - ":round", + ":cppmath", ":strided_slice_logic", ":legacy_types", ":tensor", @@ -541,7 +542,7 @@ cc_library( "@gemmlowp", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/profiler:instrumentation", + "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -600,7 +601,7 @@ cc_library( deps = [ ":common", ":compatibility", - ":round", + ":cppmath", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", @@ -621,12 +622,12 @@ cc_library( deps = [ ":common", ":compatibility", + ":cppmath", ":cpu_check", ":portable_tensor_utils", - ":round", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy", - "//tensorflow/lite/experimental/ruy:detect_arm", + "//tensorflow/lite/experimental/ruy/ruy", + "//tensorflow/lite/experimental/ruy/ruy:detect_arm", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_gemm", ], @@ -821,7 +822,7 @@ cc_test( ":reference_base", ":test_util", ":types", - "//tensorflow/lite/experimental/ruy:context", + "//tensorflow/lite/experimental/ruy/ruy:context", "//tensorflow/lite/kernels:cpu_backend_context", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/kernels/internal/round.h b/tensorflow/lite/kernels/internal/cppmath.h similarity index 54% rename from tensorflow/lite/kernels/internal/round.h rename to tensorflow/lite/kernels/internal/cppmath.h index d102d379339..611a8d2588a 100644 --- a/tensorflow/lite/kernels/internal/round.h +++ b/tensorflow/lite/kernels/internal/cppmath.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ -#define TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_CPPMATH_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_CPPMATH_H_ #include namespace tflite { -// TODO(aselle): See if we can do this only on jdk. Also mikecase, check -// if you need this for java host build. -#if defined(TF_LITE_USE_GLOBAL_ROUND) || \ +#if defined(TF_LITE_USE_GLOBAL_CMATH_FUNCTIONS) || \ (defined(__ANDROID__) && !defined(__NDK_MAJOR__)) || defined(ARDUINO) -template -inline float TfLiteRound(const float x) { - return ::round(x); -} -inline double TfLiteRound(const double x) { return ::round(x); } +#define TF_LITE_GLOBAL_STD_PREFIX #else -template -inline T TfLiteRound(const T x) { - return std::round(x); -} +#define TF_LITE_GLOBAL_STD_PREFIX std #endif +#define DECLARE_STD_GLOBAL_SWITCH1(tf_name, std_name) \ + template \ + inline T tf_name(const T x) { \ + return TF_LITE_GLOBAL_STD_PREFIX::std_name(x); \ + } + +DECLARE_STD_GLOBAL_SWITCH1(TfLiteRound, round); + } // namespace tflite -#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_CPPMATH_H_ diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index b35a66d30f2..4f8ceb33595 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy/context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 03cef848026..cd1241ed225 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h index 171475df107..2768344696d 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h index 54de6304ccc..af763377763 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index a758929a25b..1b86d91fb42 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 6fd101d1ca6..293fd4248f2 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h index fcf9272689f..e3a9b9acdc6 100644 --- a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h index c4537bbd3a5..8db98cf1bdc 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h index 9f967070413..6c1abaeff82 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index 2d0568fa4c8..d44cfabe3c3 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h index 1efe6c7e0fd..97039e2e462 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h index ef02bf194d9..153a2252f39 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h index a1e5cd7796e..fa96ce94a6e 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h index 2e01cba5d87..fdd3135097b 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index eb84cc2e9fa..952415593a5 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h index 3a6bdd2d031..fb4642e7f0d 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h @@ -28,13 +28,13 @@ limitations under the License. #include #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index cfe5ab10fb2..86e2f9fa96a 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -23,16 +23,16 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/ruy.h" +#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" +#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h" -#include "tensorflow/lite/kernels/internal/round.h" #ifdef USE_NEON diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 604ec5b4eff..34341f0e881 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/reference/add.h" @@ -38,16 +39,16 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -1129,6 +1130,48 @@ inline void Mean(const tflite::MeanParams& op_params, } } +template +inline bool MeanGeneral(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis, + U* temp_sum) { + return reference_ops::Mean(input_data, input_dims, input_num_dims, + output_data, output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, temp_sum); +} + +template <> +inline bool MeanGeneral( + const float* input_data, const int* input_dims, const int input_num_dims, + float* output_data, const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, float* temp_sum) { + // Handle reduce_mean for the last dimensions. + if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) { + ruy::profiler::ScopeLabel label("MeanLastDim/Float"); + int output_size = 1; + for (int i = 0; i < input_num_dims - 1; ++i) { + output_size *= input_dims[i]; + } + const int last_input_dim = input_dims[axis[0]]; + + // TODO(b/152563685): Consider use eigen to cover more general cases. + const MatrixMap in_mat(input_data, last_input_dim, + output_size); + VectorMap out(output_data, output_size, 1); + out = (in_mat.array().colwise().sum()) / static_cast(last_input_dim); + return true; + } + + return reference_ops::Mean(input_data, input_dims, input_num_dims, + output_data, output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, temp_sum); +} + inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& filter_shape, const float* filter_data, const RuntimeShape& bias_shape, @@ -4080,20 +4123,20 @@ inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale, } } -template +template inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { + const RuntimeShape& input_shape, const In* input_data, + const RuntimeShape& output_shape, Out* output_data) { const int trailing_dim = input_shape.DimensionsCount() - 1; const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - const int32_t clamp_max = std::numeric_limits::max(); - const int32_t clamp_min = std::numeric_limits::min(); + const int32_t clamp_max = std::numeric_limits::max(); + const int32_t clamp_min = std::numeric_limits::min(); for (int i = 0; i < excluding_last_dim; ++i) { - int32_t max_val = std::numeric_limits::min(); + int32_t max_val = std::numeric_limits::min(); // Find max quantized value. for (int j = 0; j < last_dim; ++j) { max_val = std::max(max_val, static_cast(input_data[j])); @@ -4112,8 +4155,8 @@ inline void Softmax(const SoftmaxParams& params, for (int j = 0; j < last_dim; ++j) { const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp; const int32_t prob_quantized = - QuantizeSoftmaxOutput(prob_rescaled, params.zero_point); - output_data[j] = static_cast( + QuantizeSoftmaxOutput(prob_rescaled, params.zero_point); + output_data[j] = static_cast( std::max(std::min(clamp_max, prob_quantized), clamp_min)); } input_data += last_dim; diff --git a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h index b7a2360f890..f7e54e144ce 100644 --- a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/quantization_util.cc b/tensorflow/lite/kernels/internal/quantization_util.cc index 8e28361f1f4..60e3054056d 100644 --- a/tensorflow/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/lite/kernels/internal/quantization_util.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/kernels/internal/quantization_util.h" + #include #include #include #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/quantization_util.h b/tensorflow/lite/kernels/internal/quantization_util.h index d380725257e..0ee914b0689 100644 --- a/tensorflow/lite/kernels/internal/quantization_util.h +++ b/tensorflow/lite/kernels/internal/quantization_util.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 4fe84aa3388..2a6b6d6f0f5 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/concatenation.h b/tensorflow/lite/kernels/internal/reference/concatenation.h index b511826969b..958fe3ea249 100644 --- a/tensorflow/lite/kernels/internal/reference/concatenation.h +++ b/tensorflow/lite/kernels/internal/reference/concatenation.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/dequantize.h b/tensorflow/lite/kernels/internal/reference/dequantize.h index 6bedcba1044..286c9310799 100644 --- a/tensorflow/lite/kernels/internal/reference/dequantize.h +++ b/tensorflow/lite/kernels/internal/reference/dequantize.h @@ -43,25 +43,6 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, } } -// Dequantizes into an integer with rounding. -template -inline void DequantizeInteger(const tflite::DequantizationParams& op_params, - const RuntimeShape& input_shape, - const InputT* input_data, - const RuntimeShape& output_shape, - OutputT* output_data) { - int32 zero_point = op_params.zero_point; - const double scale = op_params.scale; - const int flat_size = MatchingFlatSize(input_shape, output_shape); - - for (int i = 0; i < flat_size; i++) { - const int32 val = input_data[i]; - const OutputT result = - static_cast(round(scale * (val - zero_point))); - output_data[i] = result; - } -} - // Dequantizes per-channel quantized tensor to float. template inline void PerChannelDequantize( diff --git a/tensorflow/lite/kernels/internal/reference/fully_connected.h b/tensorflow/lite/kernels/internal/reference/fully_connected.h index 51c1deff969..fa59e1df370 100644 --- a/tensorflow/lite/kernels/internal/reference/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/fully_connected.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h index 0ee45dc3aae..20571110005 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/logistic.h b/tensorflow/lite/kernels/internal/reference/logistic.h index d53ca4c057c..8aba51896df 100644 --- a/tensorflow/lite/kernels/internal/reference/logistic.h +++ b/tensorflow/lite/kernels/internal/reference/logistic.h @@ -19,8 +19,8 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/kernels/internal/reference/pooling.h b/tensorflow/lite/kernels/internal/reference/pooling.h index 2cb23472f29..a03359cda82 100644 --- a/tensorflow/lite/kernels/internal/reference/pooling.h +++ b/tensorflow/lite/kernels/internal/reference/pooling.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_POOLING_H_ #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 045d7bdcb73..9c58415d6dc 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h" -#include "tensorflow/lite/kernels/internal/round.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/lite/kernels/internal/reference/quantize.h b/tensorflow/lite/kernels/internal/reference/quantize.h index b1ad71ea6cd..58d19c0a14c 100644 --- a/tensorflow/lite/kernels/internal/reference/quantize.h +++ b/tensorflow/lite/kernels/internal/reference/quantize.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_QUANTIZE_H_ #include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index 8bfe66cdd48..46448b2a646 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -15,8 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -371,7 +372,7 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, -input_zero_point * scale * num_elements_in_axis + 0.5f; for (size_t idx = 0; idx < num_outputs; ++idx) { const U value = - static_cast(std::round(temp_sum[idx] * scale + bias)) + + static_cast(TfLiteRound(temp_sum[idx] * scale + bias)) + output_zero_point; output_data[idx] = static_cast(value); } @@ -381,7 +382,7 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, float float_mean = static_cast(temp_sum[idx]) / static_cast(num_elements_in_axis); float result = - std::min(std::round(float_mean * scale + bias) + output_zero_point, + std::min(TfLiteRound(float_mean * scale + bias) + output_zero_point, static_cast(std::numeric_limits::max())); result = std::max(result, static_cast(std::numeric_limits::min())); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index a3430c72594..cc450ef3912 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -29,7 +29,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/add.h" @@ -57,7 +57,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/reference/sub.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/reference/requantize.h b/tensorflow/lite/kernels/internal/reference/requantize.h index fc24166cce7..8233be9ebae 100644 --- a/tensorflow/lite/kernels/internal/reference/requantize.h +++ b/tensorflow/lite/kernels/internal/reference/requantize.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/reference/softmax.h b/tensorflow/lite/kernels/internal/reference/softmax.h index 7c59b133bf3..3f19878e6a6 100644 --- a/tensorflow/lite/kernels/internal/reference/softmax.h +++ b/tensorflow/lite/kernels/internal/reference/softmax.h @@ -19,8 +19,8 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h index ae48491c04e..a9ed3a675fd 100644 --- a/tensorflow/lite/kernels/internal/reference/sub.h +++ b/tensorflow/lite/kernels/internal/reference/sub.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SUB_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index 49700dc8d12..aa735ee0c43 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" namespace tflite { diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 454a223440e..9cc146ae8bd 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc index a61f5f4dac7..f0222a08fe3 100644 --- a/tensorflow/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -359,7 +359,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } else { TF_LITE_ENSURE( context, - reference_ops::Mean( + optimized_ops::MeanGeneral( GetTensorData(op_context.input), op_context.input->dims->data, op_context.input->dims->size, GetTensorData(op_context.output), diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc index 58cd9b9076c..ddbd5106063 100644 --- a/tensorflow/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -354,6 +354,24 @@ TEST(DynamicFloatMeanOpTest, NotKeepDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } +TEST(DynamicFloatMeanOpTest, ReduceOnLastDimNotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {1}}, + false); + std::vector axis = {2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({1.5, 3.5, 5.5, 7.5, 9.5, 11.5, 13.5, + 15.5, 17.5, 19.5, 21.5, 23.5}))); +} + TEST(DynamicFloatMeanOpTest, KeepDims) { std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, diff --git a/tensorflow/lite/kernels/rfft2d.cc b/tensorflow/lite/kernels/rfft2d.cc index a06b66735f6..c0554c5e39b 100644 --- a/tensorflow/lite/kernels/rfft2d.cc +++ b/tensorflow/lite/kernels/rfft2d.cc @@ -16,7 +16,7 @@ limitations under the License. #include "third_party/fft2d/fft2d.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 3aebedcf498..bd4a0b5d152 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -46,11 +46,7 @@ cc_library( "test_helpers.h", ], build_for_embedded = True, - copts = [ - "-Werror", - "-Wdouble-promotion", - "-Wsign-compare", - ], + copts = micro_copts(), deps = [ ":micro_compatibility", ":micro_string", diff --git a/tensorflow/lite/micro/examples/micro_speech/README.md b/tensorflow/lite/micro/examples/micro_speech/README.md index 368f8056e5e..4b75041e0c8 100644 --- a/tensorflow/lite/micro/examples/micro_speech/README.md +++ b/tensorflow/lite/micro/examples/micro_speech/README.md @@ -397,16 +397,23 @@ The following instructions will help you build and deploy the sample to the [NXP FRDM K66F](https://www.nxp.com/design/development-boards/freedom-development-boards/mcu-boards/freedom-development-platform-for-kinetis-k66-k65-and-k26-mcus:FRDM-K66F) using [ARM Mbed](https://github.com/ARMmbed/mbed-cli). -1. Download [the TensorFlow source code](https://github.com/tensorflow/tensorflow). -2. Follow instructions from [mbed website](https://os.mbed.com/docs/mbed-os/v5.13/tools/installation-and-setup.html) to setup and install mbed CLI. +1. Download + [the TensorFlow source code](https://github.com/tensorflow/tensorflow). +2. Follow instructions from + [mbed website](https://os.mbed.com/docs/mbed-os/v5.13/tools/installation-and-setup.html) + to setup and install mbed CLI. 3. Compile TensorFlow with the following command to generate mbed project: ``` make -f tensorflow/lite/micro/tools/make/Makefile TARGET=mbed TAGS="nxp_k66f" generate_micro_speech_mbed_project ``` -4. Go to the location of the generated project. The generated project is usually - in `tensorflow/lite/micro/tools/make/gen/mbed_cortex-m4/prj/micro_speech/mbed` + +4. Go to the location of the generated project. The generated project is + usually in + `tensorflow/lite/micro/tools/make/gen/mbed_cortex-m4/prj/micro_speech/mbed` + 5. Create a mbed project using the generated files: `mbed new .` + 6. Change the project setting to use C++ 11 rather than C++ 14 using: ``` @@ -415,13 +422,15 @@ using [ARM Mbed](https://github.com/ARMmbed/mbed-cli). for line in fileinput.input(filename, inplace=True): print line.replace("\"-std=gnu++14\"","\"-std=c++11\", \"-fpermissive\"")' ``` + 7. To compile project, use the following command: ``` mbed compile --target K66F --toolchain GCC_ARM --profile release ``` -8. For some mbed compliers, you may get compile error in mbed_rtc_time.cpp. - Go to `mbed-os/platform/mbed_rtc_time.h` and comment line 32 and line 37: + +8. For some mbed compilers, you may get compile error in mbed_rtc_time.cpp. Go + to `mbed-os/platform/mbed_rtc_time.h` and comment line 32 and line 37: ``` //#if !defined(__GNUC__) || defined(__CC_ARM) || defined(__clang__) @@ -431,25 +440,35 @@ using [ARM Mbed](https://github.com/ARMmbed/mbed-cli). }; //#endif ``` -9. Look at helpful resources from NXP website such as [NXP FRDM-K66F User guide](https://www.nxp.com/docs/en/user-guide/FRDMK66FUG.pdf) and [NXP FRDM-K66F Getting Started](https://www.nxp.com/document/guide/get-started-with-the-frdm-k66f:NGS-FRDM-K66F) + +9. Look at helpful resources from NXP website such as + [NXP FRDM-K66F User guide](https://www.nxp.com/docs/en/user-guide/FRDMK66FUG.pdf) + and + [NXP FRDM-K66F Getting Started](https://www.nxp.com/document/guide/get-started-with-the-frdm-k66f:NGS-FRDM-K66F) to understand information about the board. + 10. Connect the USB cable to the micro USB port. When the Ethernet port is facing towards you, the micro USB port is left of the Ethernet port. -11. To compile and flash in a single step, add the `--flash` option: + +11. To compile and flash in a single step, add the `--flash` option: ``` mbed compile --target K66F --toolchain GCC_ARM --profile release --flash ``` + 12. Disconnect USB cable from the device to power down the device and connect back the power cable to start running the model. -13. Connect to serial port with baud rate of 9600 and correct serial device - to view the output from the MCU. In linux, you can run the following screen + +13. Connect to serial port with baud rate of 9600 and correct serial device to + view the output from the MCU. In linux, you can run the following screen command if the serial device is `/dev/ttyACM0`: ``` sudo screen /dev/ttyACM0 9600 ``` + 14. Saying "Yes" will print "Yes" and "No" will print "No" on the serial port. + 15. A loopback path from microphone to headset jack is enabled. Headset jack is in black color. If there is no output on the serial port, you can connect headphone to headphone port to check if audio loopback path is working. diff --git a/tensorflow/lite/micro/examples/person_detection/README.md b/tensorflow/lite/micro/examples/person_detection/README.md index 12c6b7b9b9f..5ee7bda9914 100644 --- a/tensorflow/lite/micro/examples/person_detection/README.md +++ b/tensorflow/lite/micro/examples/person_detection/README.md @@ -202,7 +202,7 @@ The next steps assume that the * The `IDF_PATH` environment variable is set * `idf.py` and Xtensa-esp32 tools (e.g. `xtensa-esp32-elf-gcc`) are in `$PATH` -* `esp32-camera` should be downloaded in `comopnents/` dir of example as +* `esp32-camera` should be downloaded in `components/` dir of example as explained in `Building the example`(below) ### Generate the examples diff --git a/tensorflow/lite/micro/examples/person_detection/esp/README_ESP.md b/tensorflow/lite/micro/examples/person_detection/esp/README_ESP.md index 35e974d985a..78a7561d5b5 100644 --- a/tensorflow/lite/micro/examples/person_detection/esp/README_ESP.md +++ b/tensorflow/lite/micro/examples/person_detection/esp/README_ESP.md @@ -16,7 +16,7 @@ The next steps assume that the [IDF environment variables are set](https://docs.espressif.com/projects/esp-idf/en/latest/get-started/index.html#step-4-set-up-the-environment-variables) : * The `IDF_PATH` environment variable is set. * `idf.py` and Xtensa-esp32 tools (e.g., `xtensa-esp32-elf-gcc`) are in `$PATH`. * `esp32-camera` should be -downloaded in `comopnents/` dir of example as explained in `Build the +downloaded in `components/` dir of example as explained in `Build the example`(below) ## Build the example diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 288f603a1f0..a0ffa342008 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -61,6 +61,7 @@ cc_library( "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:common", + "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", @@ -133,6 +134,7 @@ cc_library( "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:common", + "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", @@ -504,7 +506,10 @@ tflite_micro_cc_test( cc_library( name = "activation_utils", hdrs = ["activation_utils.h"], - deps = ["//tensorflow/lite/c:common"], + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels/internal:cppmath", + ], ) tflite_micro_cc_test( diff --git a/tensorflow/lite/micro/kernels/activation_utils.h b/tensorflow/lite/micro/kernels/activation_utils.h index 04fd2bd88a9..7525bc93b0a 100644 --- a/tensorflow/lite/micro/kernels/activation_utils.h +++ b/tensorflow/lite/micro/kernels/activation_utils.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_ACTIVATION_UTILS_H_ #define TENSORFLOW_LITE_MICRO_KERNELS_ACTIVATION_UTILS_H_ +#include #include #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" namespace tflite { namespace ops { @@ -30,11 +32,11 @@ inline float ActivationValFloat(TfLiteFusedActivation act, float a) { case kTfLiteActNone: return a; case kTfLiteActRelu: - return std::fmax(0.0f, a); + return std::max(0.0f, a); case kTfLiteActRelu1: - return std::fmax(-1.0f, std::fmin(a, 1.0f)); + return std::max(-1.0f, std::min(a, 1.0f)); case kTfLiteActRelu6: - return std::fmax(0.0f, std::fmin(a, 6.0f)); + return std::max(0.0f, std::min(a, 6.0f)); case kTfLiteActTanh: return std::tanh(a); case kTfLiteActSignBit: diff --git a/tensorflow/lite/micro/kernels/circular_buffer.cc b/tensorflow/lite/micro/kernels/circular_buffer.cc index 6b024696faa..e4bf91d9095 100644 --- a/tensorflow/lite/micro/kernels/circular_buffer.cc +++ b/tensorflow/lite/micro/kernels/circular_buffer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h" @@ -73,6 +74,8 @@ OpData op_data_array[kMaxOpDataSize]; } // namespace +void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -93,6 +96,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. + TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize); OpData* op_data = &op_data_array[op_data_counter++]; // The last circular buffer layer (length 5) simply accumulates outputs, and // does not run periodically. @@ -156,6 +160,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_CIRCULAR_BUFFER() { static TfLiteRegistration r = {}; + r.free = circular_buffer::Free; r.prepare = circular_buffer::Prepare; r.invoke = circular_buffer::Eval; return &r; diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc index 2e2efec127e..8bc1f5351cb 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h" namespace tflite { namespace ops { @@ -111,12 +110,56 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, } void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + void* raw; + context->AllocatePersistentBuffer(context, sizeof(int), &raw); + return raw; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +#if defined(__ARM_FEATURE_DSP) + OpData data; + int32_t buf_size; + + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + RuntimeShape input_shape = GetTensorShape(input); + + const int input_depth = input_shape.Dims(3); + const int input_width = input->dims->data[2]; + const int input_height = input->dims->data[1]; + const int filter_width = filter->dims->data[2]; + const int filter_height = filter->dims->data[1]; + const int output_width = output->dims->data[2]; + const int output_height = output->dims->data[1]; + + int* buffer_idx = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_STATUS(CalculateOpData( + context, node, params, input_width, input_height, filter_width, + filter_height, output_width, output_height, input->type, &data)); + + if (data.padding.width == 0 && data.padding.height == 0 && + (input_depth % 4 == 0) && params->stride_width == 1 && + params->stride_height == 1 && filter_width == 1 && filter_height == 1) { + buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(input_depth); + } else { + buf_size = arm_convolve_s8_get_buffer_size(input_depth, filter_width, + filter_height); + } + + node->user_data = buffer_idx; + if (buf_size > 0) { + context->RequestScratchBufferInArena(context, buf_size, buffer_idx); + } else { + *buffer_idx = -1; + } +#endif return kTfLiteOk; } @@ -200,15 +243,16 @@ TfLiteStatus EvalQuantizedPerChannel( const int output_width = output_shape.Dims(2); int16_t* buf = nullptr; + auto* buffer_idx = reinterpret_cast(node->user_data); + if (*buffer_idx > -1) { + void* raw = context->GetScratchBuffer(context, *buffer_idx); + buf = reinterpret_cast(raw); + } + if (op_params.padding_values.width == 0 && op_params.padding_values.height == 0 && (input_depth % 4 == 0) && - (output_depth % 2 == 0) && op_params.stride_width == 1 && - op_params.stride_height == 1 && filter_width == 1 && filter_height == 1) { - const int32_t buf_size = - arm_convolve_1x1_s8_fast_get_buffer_size(input_depth); - if (get_cmsis_scratch_buffer(context, &buf, buf_size) != kTfLiteOk) { - return kTfLiteError; - } + op_params.stride_width == 1 && op_params.stride_height == 1 && + filter_width == 1 && filter_height == 1) { if (arm_convolve_1x1_s8_fast( GetTensorData(input), input_width, input_height, input_depth, batches, GetTensorData(filter), output_depth, @@ -221,12 +265,26 @@ TfLiteStatus EvalQuantizedPerChannel( output_height, buf) != ARM_MATH_SUCCESS) { return kTfLiteError; } - } else { - const int32_t buf_size = arm_convolve_s8_get_buffer_size( + + } else if (output_height == 1 && input_height == 1 && filter_height == 1 && + (output_width % 4 == 0) && batches == 1) { + const int32_t buf_size = arm_convolve_1_x_n_s8_get_buffer_size( input_depth, filter_width, filter_height); if (get_cmsis_scratch_buffer(context, &buf, buf_size) != kTfLiteOk) { return kTfLiteError; } + if (arm_convolve_1_x_n_s8( + GetTensorData(input), input_width, input_depth, batches, + GetTensorData(filter), output_depth, filter_width, + op_params.padding_values.width, op_params.stride_width, + GetTensorData(bias), GetTensorData(output), + data->per_channel_output_shift, data->per_channel_output_multiplier, + op_params.output_offset, op_params.input_offset, + output_activation_min, output_activation_max, output_width, + buf) != ARM_MATH_SUCCESS) { + return kTfLiteError; + } + } else { if (arm_convolve_s8( GetTensorData(input), input_width, input_height, input_depth, batches, GetTensorData(filter), output_depth, diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc index f5543b85cb9..e4be31d12ed 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h" namespace tflite { namespace ops { @@ -99,12 +98,41 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + void* raw; + context->AllocatePersistentBuffer(context, sizeof(int), &raw); + return raw; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +#if defined(__ARM_FEATURE_DSP) + auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + + const int filter_width = SizeOfDimension(filter, 2); + const int filter_height = SizeOfDimension(filter, 1); + + RuntimeShape input_shape = GetTensorShape(input); + const int input_depth = input_shape.Dims(3); + + int* buffer_idx = reinterpret_cast(node->user_data); + + *buffer_idx = -1; + node->user_data = buffer_idx; + + if (params->depth_multiplier == 1) { + const int32_t buf_size = arm_depthwise_conv_s8_opt_get_buffer_size( + input_depth, filter_width, filter_height); + + if (buf_size > 0) { + context->RequestScratchBufferInArena(context, buf_size, buffer_idx); + } + } +#endif return kTfLiteOk; } @@ -174,10 +202,12 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, if (op_params.depth_multiplier == 1) { int16_t* buf = nullptr; - const int32_t buf_size = arm_depthwise_conv_s8_opt_get_buffer_size( - input_depth, filter_width, filter_height); - TF_LITE_ENSURE_OK(context, - get_cmsis_scratch_buffer(context, &buf, buf_size)); + auto* buffer_idx = reinterpret_cast(node->user_data); + if (*buffer_idx > -1) { + void* raw = context->GetScratchBuffer(context, *buffer_idx); + buf = reinterpret_cast(raw); + } + TF_LITE_ENSURE_EQ( context, arm_depthwise_conv_s8_opt( diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index 20980d726c6..88e32ba5d8c 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h" namespace tflite { namespace ops { @@ -73,14 +72,32 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + void* raw; + context->AllocatePersistentBuffer(context, sizeof(int), &raw); + return raw; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // todo: call AllocateTemporaryTensor() instead of using - // get_cmsis_scratch_buffer() +#if defined(__ARM_FEATURE_DSP) + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + + RuntimeShape filter_shape = GetTensorShape(filter); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + + const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(accum_depth); + + int* buffer_idx = reinterpret_cast(node->user_data); + + node->user_data = buffer_idx; + if (buf_size > 0) { + context->RequestScratchBufferInArena(context, buf_size, buffer_idx); + } else { + *buffer_idx = -1; + } +#endif return kTfLiteOk; } @@ -97,9 +114,14 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, const int accum_depth = filter_shape.Dims(filter_dim_count - 1); #if defined(__ARM_FEATURE_DSP) - const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(accum_depth); int16_t* buf = nullptr; - TF_LITE_ENSURE_OK(context, get_cmsis_scratch_buffer(context, &buf, buf_size)); + + auto* buffer_idx = reinterpret_cast(node->user_data); + if (*buffer_idx > -1) { + void* raw = context->GetScratchBuffer(context, *buffer_idx); + buf = reinterpret_cast(raw); + } + TF_LITE_ENSURE_EQ( context, arm_fully_connected_s8( diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc index 54dcf64118e..74cf10f5a73 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc @@ -16,7 +16,6 @@ limitations under the License. // These are headers from the ARM CMSIS-NN library. #include "arm_nnfunctions.h" // NOLINT -#include "scratch_buffer.h" // NOLINT #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -128,10 +127,13 @@ TfLiteStatus AverageEvalInt8(TfLiteContext* context, const TfLiteNode* node, const int padding_width = data->padding.width; int16_t* scratch_buffer = nullptr; - int32_t buffer_size = arm_avgpool_s8_get_buffer_size(output_width, depth); - TF_LITE_ENSURE_OK( - context, get_cmsis_scratch_buffer(context, &scratch_buffer, buffer_size)); + auto* buffer_idx = reinterpret_cast(node->user_data); + + if (*buffer_idx > -1) { + void* raw = context->GetScratchBuffer(context, *buffer_idx); + scratch_buffer = reinterpret_cast(raw); + } TF_LITE_ENSURE_EQ( context, @@ -207,12 +209,39 @@ void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + void* raw; + context->AllocatePersistentBuffer(context, sizeof(int), &raw); + return raw; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +#if defined(__ARM_FEATURE_DSP) + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + RuntimeShape input_shape = GetTensorShape(input); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + + RuntimeShape output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int output_width = output_shape.Dims(2); + + const int32_t buffer_size = + arm_avgpool_s8_get_buffer_size(output_width, depth); + + int* buffer_idx = reinterpret_cast(node->user_data); + + node->user_data = buffer_idx; + if (buffer_size > 0) { + context->RequestScratchBufferInArena(context, buffer_size, buffer_idx); + } else { + *buffer_idx = -1; + } +#endif return kTfLiteOk; } diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc b/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc deleted file mode 100644 index e15a1416aeb..00000000000 --- a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "scratch_buffer.h" - -// todo: remove this function once context->AllocateTemporaryTensor() is -// implemented. - -// This buffer is used by CMSIS-NN optimized operator implementations. -// SCRATCH_BUFFER_BYTES bytes is chosen empirically. It needs to be large -// enough to hold the biggest buffer needed by all CMSIS-NN operators in the -// network. -// note: buffer must be 32-bit aligned for SIMD -#define SCRATCH_BUFFER_BYTES 13000 - -TfLiteStatus get_cmsis_scratch_buffer(TfLiteContext* context, int16_t** buf, - int32_t buf_size_bytes) { - __attribute__((aligned( - 4))) static int16_t cmsis_scratch_buffer[SCRATCH_BUFFER_BYTES / 2] = {0}; - - TF_LITE_ENSURE(context, buf_size_bytes <= SCRATCH_BUFFER_BYTES); - *buf = cmsis_scratch_buffer; - return kTfLiteOk; -} diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc index 8cfa5413ad1..108f0cfbf4c 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc @@ -25,35 +25,37 @@ namespace micro { namespace activations { namespace { -struct OpData { - int32_t input_multiplier = 0; - int input_left_shift = 0; - int32_t input_range_radius = 0; - int diff_min = 0; -}; - -TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, +TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, const TfLiteSoftmaxParams* params, - OpData* data) { + SoftmaxParams* op_data) { if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); } - TF_LITE_ENSURE(context, (output->params.scale == 1.f / 256) || (output->params.scale == 1.f / 255)); static const int kScaledDiffIntegerBits = 5; + int input_left_shift; tflite::PreprocessSoftmaxScaling( params->beta, input->params.scale, kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift); - data->diff_min = -1.0 * tflite::CalculateInputRadius( - kScaledDiffIntegerBits, data->input_left_shift); + &op_data->input_multiplier, &input_left_shift); + op_data->input_left_shift = input_left_shift; + op_data->diff_min = + -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + op_data->input_left_shift); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); + op_data->beta = static_cast(params->beta); } return kTfLiteOk; } @@ -75,26 +77,19 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// Takes a 4D tensor and perform softmax along the forth dimension. +// Takes a tensor and performs softmax along the last dimension. void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params) { - SoftmaxParams op_params; - op_params.beta = params->beta; + const SoftmaxParams& op_data) { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; - + const SoftmaxParams& op_data) { if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { const unsigned int num_dims = NumDimensions(input); @@ -106,30 +101,29 @@ void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); arm_softmax_s8(GetTensorData(input), outer_size, depth, - op_params.input_multiplier, op_params.input_left_shift, - op_params.diff_min, GetTensorData(output)); + op_data.input_multiplier, op_data.input_left_shift, + op_data.diff_min, GetTensorData(output)); } } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); + auto* params = static_cast(node->builtin_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - OpData local_data_object; - OpData* data = &local_data_object; + SoftmaxParams op_data; TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, data)); + CalculateSoftmaxParams(context, input, output, params, &op_data)); switch (input->type) { case kTfLiteFloat32: { - SoftmaxFloat(input, output, params); + SoftmaxFloat(input, output, op_data); return kTfLiteOk; } - case kTfLiteUInt8: - case kTfLiteInt8: { - SoftmaxQuantized(input, output, params, data); + case kTfLiteInt8: + case kTfLiteUInt8: { + SoftmaxQuantized(input, output, params, op_data); return kTfLiteOk; } default: diff --git a/tensorflow/lite/micro/kernels/dequantize.cc b/tensorflow/lite/micro/kernels/dequantize.cc index 21c34a4f8fd..7f25493dab4 100644 --- a/tensorflow/lite/micro/kernels/dequantize.cc +++ b/tensorflow/lite/micro/kernels/dequantize.cc @@ -49,11 +49,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - tflite::DequantizationParams op_params; - op_params.zero_point = input->params.zero_point; - op_params.scale = static_cast(input->params.scale); - if (output->type == kTfLiteFloat32) { + tflite::DequantizationParams op_params; + op_params.zero_point = input->params.zero_point; + op_params.scale = static_cast(input->params.scale); switch (input->type) { case kTfLiteUInt8: reference_ops::Dequantize( @@ -77,24 +76,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } else if (output->type == kTfLiteInt32) { + int32_t output_multiplier; + int output_shift; + const double effective_output_scale = + static_cast(input->params.scale) / + static_cast(output->params.scale); + QuantizeMultiplier(effective_output_scale, &output_multiplier, + &output_shift); + int flat_size = + MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); switch (input->type) { - // TODO(b/148749335): DequantizeInteger and Requantize are hacks here. case kTfLiteInt16: { - reference_ops::DequantizeInteger( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); + reference_ops::Requantize( + GetTensorData(input), flat_size, output_multiplier, + output_shift, input->params.zero_point, output->params.zero_point, + GetTensorData(output)); break; } case kTfLiteInt8: { - int32_t output_multiplier; - int output_shift; - const double effective_output_scale = - static_cast(input->params.scale) / - static_cast(output->params.scale); - QuantizeMultiplier(effective_output_scale, &output_multiplier, - &output_shift); - int flat_size = - MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); reference_ops::Requantize( GetTensorData(input), flat_size, output_multiplier, output_shift, input->params.zero_point, output->params.zero_point, diff --git a/tensorflow/lite/micro/kernels/dequantize_test.cc b/tensorflow/lite/micro/kernels/dequantize_test.cc index b1c8f8499fe..5831791248c 100644 --- a/tensorflow/lite/micro/kernels/dequantize_test.cc +++ b/tensorflow/lite/micro/kernels/dequantize_test.cc @@ -189,4 +189,22 @@ TF_LITE_MICRO_TEST(DequantizeOpTestInt8ToInt32) { golden, output_scale, output_zero_point, output); } +TF_LITE_MICRO_TEST(DequantizeOpTestInt16ToInt32) { + const int length = 10; + const int dims[] = {2, 5, 2}; + const float input_float[] = {-63.5, -63, -62.5, -62, -61.5, + 62, 62.5, 63, 63.5, 64}; + const int32_t golden[] = {-630, -625, -620, -615, -610, + 625, 630, 635, 640, 645}; + const float input_scale = 0.5f; + const int input_zero_point = -1; + const float output_scale = 0.1f; + const int output_zero_point = 5; + int16_t input_quantized[length]; + int32_t output[length]; + tflite::testing::TestDequantizeToInt32( + dims, input_float, input_quantized, input_scale, input_zero_point, dims, + golden, output_scale, output_zero_point, output); +} + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc index 8ad69ce157b..e394a227a98 100644 --- a/tensorflow/lite/micro/kernels/quantize.cc +++ b/tensorflow/lite/micro/kernels/quantize.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/requantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { namespace ops { @@ -82,13 +84,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } else if (input->type == kTfLiteInt16) { + size_t size = ElementCount(*input->dims); + int32_t output_multiplier; + int output_shift; + double effective_scale = + static_cast(input->params.scale / output->params.scale); switch (output->type) { case kTfLiteInt8: - reference_ops::AffineQuantize( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); + QuantizeMultiplier(effective_scale, &output_multiplier, &output_shift); + reference_ops::Requantize( + GetTensorData(input), size, output_multiplier, + output_shift, input->params.zero_point, output->params.zero_point, + GetTensorData(output)); break; - default: TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc index 869801afff8..359abbd73db 100644 --- a/tensorflow/lite/micro/kernels/quantize_test.cc +++ b/tensorflow/lite/micro/kernels/quantize_test.cc @@ -25,37 +25,15 @@ namespace testing { namespace { template -void TestQuantize(const int* input_dims_data, const float* input_data, - const int* output_dims_data, const float* golden, - T* golden_quantized, float scale, int zero_point, - T* output_data) { - TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); - TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); - const int output_dims_count = ElementCount(*output_dims); - - ::tflite::ops::micro::AllOpsResolver resolver; - - TfLiteTensor output_tensor = CreateQuantizedTensor( - output_data, output_dims, scale, zero_point, "output_tensor"); - - TfLiteAffineQuantization quant; - float scales[] = {1, scale}; - int zero_points[] = {1, zero_point}; - quant.scale = FloatArrayFromFloats(scales); - quant.zero_point = IntArrayFromInts(zero_points); - output_tensor.quantization = {kTfLiteAffineQuantization, &quant}; - - // 1 input, 1 output. - constexpr int tensors_size = 2; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims, "input_tensor"), - output_tensor, - }; - +void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size, + const float* golden, T* golden_quantized, + float scale, int zero_point, int output_len, + T* output_data) { TfLiteContext context; PopulateContext(tensors, tensors_size, micro_test::reporter, &context); // Version 1 of quantize supports int8 and uint8 quantization. + ::tflite::ops::micro::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp(tflite::BuiltinOperator_QUANTIZE, 1); @@ -96,13 +74,77 @@ void TestQuantize(const int* input_dims_data, const float* input_data, } // Use reference quantization from test utils to compare against op output. - AsymmetricQuantize(golden, golden_quantized, output_dims_count, scale, - zero_point); - for (int i = 0; i < output_dims_count; ++i) { + AsymmetricQuantize(golden, golden_quantized, output_len, scale, zero_point); + for (int i = 0; i < output_len; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); } } +template +void TestQuantizeFloat(const int* input_dims_data, const float* input_data, + const int* output_dims_data, const float* golden, + T* golden_quantized, const float scale, + const int zero_point, T* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + TfLiteTensor output_tensor = CreateQuantizedTensor( + output_data, output_dims, scale, zero_point, "output_tensor"); + + TfLiteAffineQuantization quant; + float scales[] = {1, scale}; + int zero_points[] = {1, zero_point}; + quant.scale = FloatArrayFromFloats(scales); + quant.zero_point = IntArrayFromInts(zero_points); + output_tensor.quantization = {kTfLiteAffineQuantization, &quant}; + + // 1 input, 1 output. + constexpr int tensors_size = 2; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + output_tensor, + }; + + ValidateQuantizeGoldens(tensors, tensors_size, golden, golden_quantized, + scale, zero_point, output_dims_count, output_data); +} + +template +void TestQuantizeInt16(const int* input_dims_data, const float* input_data, + int16_t* input_quantized, const float input_scale, + const int input_zero_point, const int* output_dims_data, + const float* golden, T* golden_quantized, + const float output_scale, const int output_zero_point, + T* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + TfLiteTensor output_tensor = + CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point, "output_tensor"); + + TfLiteAffineQuantization quant; + float scales[] = {1, output_scale}; + int zero_points[] = {1, output_zero_point}; + quant.scale = FloatArrayFromFloats(scales); + quant.zero_point = IntArrayFromInts(zero_points); + output_tensor.quantization = {kTfLiteAffineQuantization, &quant}; + + // 1 input, 1 output. + constexpr int tensors_size = 2; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point, "input_tensor"), + output_tensor, + }; + + ValidateQuantizeGoldens(tensors, tensors_size, golden, golden_quantized, + output_scale, output_zero_point, output_dims_count, + output_data); +} + } // namespace } // namespace testing } // namespace tflite @@ -118,8 +160,8 @@ TF_LITE_MICRO_TEST(QuantizeOpTestUint8) { const int zero_point = 127; uint8_t output[length]; uint8_t values_quantized[length]; - tflite::testing::TestQuantize(dims, values, dims, values, values_quantized, - scale, zero_point, output); + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); } TF_LITE_MICRO_TEST(QuantizeOpTestUint8NoScale) { @@ -131,8 +173,8 @@ TF_LITE_MICRO_TEST(QuantizeOpTestUint8NoScale) { const int zero_point = 127; uint8_t output[length]; uint8_t values_quantized[length]; - tflite::testing::TestQuantize(dims, values, dims, values, values_quantized, - scale, zero_point, output); + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); } TF_LITE_MICRO_TEST(QuantizeOpTestInt8) { @@ -144,8 +186,8 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt8) { const int zero_point = -1; uint8_t output[length]; uint8_t values_quantized[length]; - tflite::testing::TestQuantize(dims, values, dims, values, values_quantized, - scale, zero_point, output); + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); } TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) { @@ -157,8 +199,24 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) { const int zero_point = 0; uint8_t output[length]; uint8_t values_quantized[length]; - tflite::testing::TestQuantize(dims, values, dims, values, values_quantized, - scale, zero_point, output); + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); +} + +TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) { + const int length = 10; + const int dims[] = {2, 2, 5}; + const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62}; + const float input_scale = 2.f; + const int input_zero_point = 0; + const float output_scale = 0.5; + const int output_zero_point = 0; + int8_t output_quantized[length]; + int16_t input_quantized[length]; + tflite::testing::TestQuantizeInt16(dims, values, input_quantized, input_scale, + input_zero_point, dims, values, + output_quantized, output_scale, + output_zero_point, output_quantized); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index c213d6646ed..1c95d6db8fb 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -29,27 +29,23 @@ namespace micro { namespace activations { namespace { -struct OpData { - int32_t input_multiplier = 0; - int input_left_shift = 0; - int32_t input_range_radius = 0; - int diff_min = 0; -}; - -TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, +TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, const TfLiteSoftmaxParams* params, - OpData* data) { + SoftmaxParams* op_data) { if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); if (output->type == kTfLiteInt16) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); // NOTE: Current int16 softmax output does not require symmetric scaling // - so no need to verify scale here. } else { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); } @@ -57,12 +53,19 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, static const int kScaledDiffIntegerBits = 5; + int input_left_shift; tflite::PreprocessSoftmaxScaling( static_cast(params->beta), static_cast(input->params.scale), kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift); - data->diff_min = -1.0 * tflite::CalculateInputRadius( - kScaledDiffIntegerBits, data->input_left_shift); + &op_data->input_multiplier, &input_left_shift); + op_data->input_left_shift = input_left_shift; + op_data->diff_min = + -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + op_data->input_left_shift); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); + op_data->beta = static_cast(params->beta); } return kTfLiteOk; } @@ -86,56 +89,49 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { // Takes a tensor and performs softmax along the last dimension. void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params) { - SoftmaxParams op_params; - op_params.beta = static_cast(params->beta); + const SoftmaxParams& op_data) { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; + const SoftmaxParams& op_data) { if (input->type == kTfLiteUInt8) { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { if (output->type == kTfLiteInt16) { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } else { tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), + op_data, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } } } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); + auto* params = static_cast(node->builtin_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - OpData local_data_object; - OpData* data = &local_data_object; + SoftmaxParams op_data; TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, data)); + CalculateSoftmaxParams(context, input, output, params, &op_data)); switch (input->type) { case kTfLiteFloat32: { - SoftmaxFloat(input, output, params); + SoftmaxFloat(input, output, op_data); return kTfLiteOk; } case kTfLiteInt8: case kTfLiteUInt8: { - SoftmaxQuantized(input, output, params, data); + SoftmaxQuantized(input, output, op_data); return kTfLiteOk; } default: @@ -149,11 +145,14 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } // namespace activations TfLiteRegistration* Register_SOFTMAX() { - static TfLiteRegistration r = {}; - r.init = activations::Init; - r.free = activations::Free; - r.prepare = activations::SoftmaxPrepare; - r.invoke = activations::SoftmaxEval; + static TfLiteRegistration r = {activations::Init, + activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval, + nullptr, + 0, + nullptr, + 0}; return &r; } diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc new file mode 100644 index 00000000000..d22912959ae --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc @@ -0,0 +1,241 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace activations { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template +inline void ReluQuantized(int32_t lower, const RuntimeShape& input_shape, + const Q* input_data, const RuntimeShape& output_shape, + Q* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const Q val = input_data[i]; + const Q clamped = val < lower ? lower : val; + output_data[i] = clamped; + } +} + +inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const float val = input_data[i]; + const float lower = 0.0f; + const float clamped = val < lower ? lower : val; + output_data[i] = clamped; + } +} + +inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const float val = input_data[i]; + const float upper = 6.0f; + const float lower = 0.0f; + const float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[i] = clamped; + } +} + +template +inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape, + const Q* input_data, + const RuntimeShape& output_shape, Q* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const Q val = input_data[i]; + const Q clamped = val > upper ? upper : val < lower ? lower : val; + output_data[i] = clamped; + } +} + +TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input->type) { + case kTfLiteFloat32: { + int err; + const float* inp_data_ptr; + float* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + const float f32_pos_inf = 0x7F800000; + err = xa_nn_vec_relu_f32_f32(out_data_ptr, inp_data_ptr, f32_pos_inf, + flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu1_f32_f32 failed"); + return kTfLiteOk; + } + case kTfLiteInt8: { + ReluQuantized(input->params.zero_point, GetTensorShape(input), + GetTensorData(input), + GetTensorShape(output), + GetTensorData(output)); + return kTfLiteOk; + } + case kTfLiteUInt8: { + int err; + const uint8_t* inp_data_ptr; + uint8_t* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_vec_activation_min_max_asym8_asym8( + out_data_ptr, inp_data_ptr, 0, 255, flat_size); // Is 255 right? + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_activation_min_max_8_8 failed"); + return kTfLiteOk; + } + default: { + TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + } +} + +TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input->type) { + case kTfLiteFloat32: { + int err; + const float* inp_data_ptr; + float* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_vec_relu6_f32_f32(out_data_ptr, inp_data_ptr, flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu1_f32_f32 failed"); + return kTfLiteOk; + } + case kTfLiteInt8: { + const int8_t six = FloatToAsymmetricQuantizedInt8( + 6.0f, input->params.scale, input->params.zero_point); + const int8_t zero = input->params.zero_point; + Relu6Quantized( + zero, six, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; + } + case kTfLiteUInt8: { + const uint8_t six = FloatToAsymmetricQuantizedUInt8( + 6.0f, input->params.scale, input->params.zero_point); + const uint8_t zero = input->params.zero_point; + int err; + const uint8_t* inp_data_ptr; + uint8_t* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_vec_activation_min_max_asym8_asym8(out_data_ptr, inp_data_ptr, + zero, six, flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_activation_min_max_8_8 failed"); + return kTfLiteOk; + } + default: { + TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + } +} + +} // namespace activations + +TfLiteRegistration* Register_RELU() { + static TfLiteRegistration r = {}; + r.prepare = activations::ReluPrepare; + r.invoke = activations::ReluEval; + return &r; +} + +TfLiteRegistration* Register_RELU6() { + static TfLiteRegistration r = {}; + r.prepare = activations::Relu6Prepare; + r.invoke = activations::Relu6Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc new file mode 100755 index 00000000000..9536d5ec0e2 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc @@ -0,0 +1,549 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kConvQuantizedDimension = 0; + +// This file has 2 implementation of Conv. + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + // (b/141139247): Allocate these dynamically when possible. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +inline PaddingType RuntimePaddingType(TfLitePadding padding) { + switch (padding) { + case TfLitePadding::kTfLitePaddingSame: + return PaddingType::kSame; + case TfLitePadding::kTfLitePaddingValid: + return PaddingType::kValid; + case TfLitePadding::kTfLitePaddingUnknown: + default: + return PaddingType::kNone; + } +} + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, int width, int height, + int filter_width, int filter_height, int out_width, + int out_height, const TfLiteType data_type, + OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + params->dilation_height_factor, params->dilation_width_factor, height, + width, filter_height, filter_width, padding, &out_height, &out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int output_channels = filter->dims->data[kConvQuantizedDimension]; + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast(data->per_channel_output_shift), + output_channels)); + } + return kTfLiteOk; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + if ((params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1)) { + const uint8 *input_data, *filter_data; + const int32_t* bias_data; + uint8* output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int32 output_activation_min = data->output_activation_min; + const int32 output_activation_max = data->output_activation_max; + const int32 output_multiplier = data->output_multiplier; + const int output_shift = -data->output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + + int err, output_data_format = 0; + void* p_scratch; + uint8 *p_filter, *p_out_scratch; + // Calculate filter_depth_padded as next near multiple of 4 + int filter_depth_padded = (filter_depth + 3) & (~3); + int out_length = output_height * output_width * output_depth; + int required_scratch, input_precision = PREC_ASYM8; + int h, w, c; + + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, stride_height, + pad_height, output_height, input_precision); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, + "conv2d_std_asym8: xa_nn_conv2d_std_getsize failed"); + return kTfLiteError; + } + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + p_filter = (uint8*)p_scratch; + p_out_scratch = + (p_filter + + ALIGNED_SIZE((sizeof(uint8_t) * filter_height * filter_width * + filter_depth_padded * output_depth), + 8)); + required_scratch += + ALIGNED_SIZE((sizeof(uint8_t) * filter_height * filter_width * + filter_depth_padded * output_depth), + 8); + p_scratch = + (uint8*)(p_out_scratch + ALIGNED_SIZE(sizeof(uint8_t) * out_length, 8)); + required_scratch += ALIGNED_SIZE(sizeof(uint8_t) * out_length, 8); + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "conv2d_std_asym8: insufficient scratch memory"); + return kTfLiteError; + } + + // Padding filter coefficients depthwise + for (h = 0; h < filter_height * filter_width * output_depth; h++) { + for (c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = + filter_data[h * filter_depth + c]; + } + for (c = input_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = + -filter_offset; // filter_depth[h*input_depth + c]; + } + } + + for (int batch = 0; batch < batches; ++batch) { + uint8* p_out_temp; + p_out_temp = (uint8*)&p_out_scratch[0]; + p_out_temp = (uint8*)ALIGN_PTR(p_out_temp, 8); + + err = xa_nn_conv2d_std_asym8xasym8( + p_out_temp, + &input_data[batch * input_height * input_width * input_depth], + p_filter, // filter_data, + bias_data, input_height, input_width, input_depth, filter_height, + filter_width, output_depth, stride_width, stride_height, pad_width, + pad_height, output_height, output_width, input_offset, filter_offset, + output_multiplier, output_shift, output_offset, output_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "conv2d_std_asym8: xa_nn_conv2d_std_asym8xasym8 failed"); + + for (int i = 0; i < out_length; i++) { + uint8* p_temp; + p_temp = &output_data[batch * out_length]; + + ACTIVATION_MIN_MAX_ASYM8(p_temp[i], p_out_temp[i], + output_activation_min, output_activation_max) + } + } + } else { + ConvParams op_params; + op_params.padding_type = RuntimePaddingType(params->padding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col), nullptr); + } + return kTfLiteOk; +} + +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + TfLiteTensor* im2col) { + ConvParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::ConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + if ((params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1)) { + const float *input_data, *filter_data; + const float* bias_data; + float* output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + int err, output_data_format = 0; + void* p_scratch; + float *p_filter, *p_out_scratch; + // Calculate filter_depth_padded as next near multiple of 2 + int filter_depth_padded = (filter_depth + 1) & (~1); + int out_length = output_height * output_width * output_depth; + int required_scratch, input_precision = PREC_F32; + int h, w, c; + + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, stride_height, + pad_height, output_height, input_precision); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, + "conv2d_std_f32: xa_nn_conv2d_std_getsize failed"); + return kTfLiteError; + } + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + p_filter = (float*)p_scratch; + p_out_scratch = + (float*)((uint8_t*)p_filter + + ALIGNED_SIZE((sizeof(float) * filter_height * filter_width * + filter_depth_padded * output_depth), + 8)); + required_scratch += + ALIGNED_SIZE((sizeof(float) * filter_height * filter_width * + filter_depth_padded * output_depth), + 8); + p_scratch = (float*)((uint8_t*)p_out_scratch + + ALIGNED_SIZE(sizeof(float) * out_length, 8)); + required_scratch += ALIGNED_SIZE(sizeof(float) * out_length, 8); + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "conv2d_std_f32: insufficient scratch memory"); + return kTfLiteError; + } + + // Padding filter coefficients depthwise + for (h = 0; h < filter_height * filter_width * output_depth; h++) { + for (c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = + filter_data[h * filter_depth + c]; + } + for (c = input_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = 0; + } + } + + for (int batch = 0; batch < batches; ++batch) { + float* p_out_temp; + p_out_temp = (float*)&p_out_scratch[0]; + p_out_temp = (float*)ALIGN_PTR(p_out_temp, 8); + + err = xa_nn_conv2d_std_f32( + p_out_temp, + &input_data[batch * input_height * input_width * input_depth], + p_filter, bias_data, input_height, input_width, input_depth, + filter_height, filter_width, output_depth, stride_width, + stride_height, pad_width, pad_height, output_height, output_width, + output_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "conv2d_std_f32: xa_nn_conv2d_std_f32xf32 failed"); + + for (int i = 0; i < out_length; i++) { + float* p_temp; + p_temp = &output_data[batch * out_length]; + ACTIVATION_MIN_MAX(float, p_temp[i], p_out_temp[i], + output_activation_min, output_activation_max) + } + } + } else { + ConvParams op_params; + op_params.padding_type = RuntimePaddingType(params->padding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col)); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + int input_width = input->dims->data[2]; + int input_height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int output_width = output->dims->data[2]; + int output_height = output->dims->data[1]; + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + + TF_LITE_ENSURE(context, + affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + + TF_LITE_ENSURE_STATUS(CalculateOpData( + context, node, params, input_width, input_height, filter_width, + filter_height, output_width, output_height, input->type, &data)); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, &data, input, filter, bias, nullptr, + nullptr, output); + break; + case kTfLiteInt8: + EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, + output, nullptr); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, &data, input, filter, bias, nullptr, + nullptr, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONV_2D() { + static TfLiteRegistration r = {}; + r.prepare = conv::Prepare; + r.invoke = conv::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc new file mode 100755 index 00000000000..fdb66d6cbc4 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc @@ -0,0 +1,560 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Depthwise conv is quantized along dimension 3: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kDepthwiseConvQuantizedDimension = 3; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + // (b/141139247): Allocate these dynamically when possible. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + const TfLiteType data_type, OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + int unused_output_height, unused_output_width; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, 1, 1, height, width, + filter_height, filter_width, params->padding, &unused_output_height, + &unused_output_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast(data->per_channel_output_shift), num_channels)); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + if ((params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1)) { + const float *input_data, *filter_data, *bias_data; + float* output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + // const int dilation_width_factor = params->dilation_width_factor;; + // const int dilation_height_factor = params->dilation_height_factor; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int depth_multiplier = params->depth_multiplier; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + int32_t err, input_data_format = 0, output_data_format = 0; + void* p_scratch; + float* p_filter; + int filter_depth_padded, filter_size_padded, required_scratch; + int input_precision = PREC_F32; + int h, c, i; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + filter_depth_padded = (filter_depth + 1) & (~1); + filter_size_padded = filter_height * filter_width * filter_depth_padded; + + required_scratch = xa_nn_conv2d_depthwise_getsize( + input_height, input_width, input_depth, filter_height, filter_width, + depth_multiplier, stride_width, stride_height, pad_width, pad_height, + output_height, output_width, input_precision, input_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG( + context, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_getsize failed"); + return kTfLiteError; + } + + required_scratch += ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8); + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "DepthwiseConvFloat: insufficient scratch memory"); + return kTfLiteError; + } + + p_filter = (float*)p_scratch; + p_scratch = (void*)((uint8_t*)p_filter + + ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8)); + + for (h = 0; h < filter_height * filter_width; h++) { + for (c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = + filter_data[h * filter_depth + c]; + } + for (c = filter_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = 0; + } + } + + for (i = 0; i < batches; i++) { + err = xa_nn_conv2d_depthwise_f32( + &output_data[i * output_height * output_width * output_depth], + p_filter, // filter_data, + &input_data[i * input_height * input_width * input_depth], bias_data, + input_height, input_width, input_depth, filter_height, filter_width, + depth_multiplier, stride_width, stride_height, pad_width, pad_height, + output_height, output_width, input_data_format, output_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_f32 failed"); + } + + // pre loop for activation_min_max to handle alignment + int out_length = batches * output_height * output_width * output_depth; + uint32 p_unalign_val = (uint32)output_data, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, output_data[i], output_data[i], + output_activation_min, output_activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length) { + err = xa_nn_vec_activation_min_max_f32_f32( + &output_data[i], &output_data[i], output_activation_min, + output_activation_max, out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, + "DepthwiseConvFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + } else { + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + DepthwiseParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = 0; + op_params.output_offset = output->params.zero_point; + // (b/130439627): Use calculated value for clamping. + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + + reference_integer_ops::DepthwiseConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + if ((params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1)) { + const uint8 *input_data, *filter_data; + const int32_t* bias_data; + uint8* output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + // const int dilation_width_factor = params->dilation_width_factor; + // const int dilation_height_factor = params->dilation_height_factor; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int depth_multiplier = params->depth_multiplier; + const int32 output_activation_min = data->output_activation_min; + const int32 output_activation_max = data->output_activation_max; + const int32 output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + const int output_shift = -data->output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + int32_t err, i, input_data_format = 0, output_data_format = 0; + void* p_scratch; + uint8* p_filter; + int filter_depth_padded, filter_size_padded, required_scratch; + int input_precision = PREC_ASYM8; + int h, c; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_conv2d_depthwise_getsize( + input_height, input_width, input_depth, filter_height, filter_width, + depth_multiplier, stride_width, stride_height, pad_width, pad_height, + output_height, output_width, input_precision, input_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG( + context, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_getsize failed"); + return kTfLiteError; + } + + filter_depth_padded = (filter_depth + 3) & (~3); + filter_size_padded = filter_height * filter_width * filter_depth_padded; + required_scratch += ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8); + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "DepthwiseConvAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + p_filter = (uint8*)p_scratch; + p_scratch = (void*)(p_filter + + ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8)); + + for (h = 0; h < filter_height * filter_width; h++) { + for (c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = + filter_data[h * filter_depth + c]; + } + for (c = filter_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = -filter_offset; + } + } + + for (i = 0; i < batches; i++) { + err = xa_nn_conv2d_depthwise_asym8xasym8( + &output_data[i * output_height * output_width * output_depth], + p_filter, // filter_data, + &input_data[i * input_height * input_width * input_depth], bias_data, + input_height, input_width, input_depth, filter_height, filter_width, + depth_multiplier, stride_width, stride_height, pad_width, pad_height, + output_height, output_width, input_offset, filter_offset, + output_multiplier, output_shift, output_offset, input_data_format, + output_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_asym8xasym8 failed"); + } + + // pre loop for activation_min_max to handle alignment + int out_length = batches * output_height * output_width * output_depth; + uint32 p_unalign_val = (uint32)output_data, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX_ASYM8(output_data[i], output_data[i], + output_activation_min, output_activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length > 0) { + err = xa_nn_vec_activation_min_max_asym8_asym8( + &output_data[i], &output_data[i], output_activation_min, + output_activation_max, out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, + "DepthwiseConvAsym8: xa_nn_vec_activation_min_max_asym8_asym8 " + "failed"); + } + } else { + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + TF_LITE_ENSURE( + context, affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kDepthwiseConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, data_type, + &data)); + + // (aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, &data, input, filter, bias, output); + break; + case kTfLiteInt8: + EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, &data, input, filter, bias, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {}; + r.init = depthwise_conv::Init; + r.free = depthwise_conv::Free; + r.prepare = depthwise_conv::Prepare; + r.invoke = depthwise_conv::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc new file mode 100644 index 00000000000..5358ce1780a --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc @@ -0,0 +1,81 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/floor.h" + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace floor { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int err; + const float* inp_data_ptr; + float* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_elm_floor_f32_f32(out_data_ptr, inp_data_ptr, flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_floor_f32_f32 failed"); + return kTfLiteOk; +} +} // namespace floor + +TfLiteRegistration* Register_FLOOR() { + static TfLiteRegistration r = {}; + r.invoke = floor::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc new file mode 100644 index 00000000000..1a576cd7e9c --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc @@ -0,0 +1,277 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace fully_connected { +namespace { + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus CalculateOpData(TfLiteContext* context, + TfLiteFullyConnectedParams* params, + TfLiteType data_type, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpData* data) { + TfLiteStatus status = kTfLiteOk; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + } + return status; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + FullyConnectedParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = -filter->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.output_multiplier = data->output_multiplier; + // (b/138810107): Figure out whether output shift should be inverted + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(filter), GetTensorData(filter), \ + GetTensorShape(bias), GetTensorData(bias), \ + GetTensorShape(output), GetTensorData(output)) + switch (output->type) { + case kTfLiteUInt8: { + int ret, b, weight_depth, out_depth, batches; + uint8_t* p_out = GetTensorData(output); + weight_depth = GetTensorShape(filter).Dims( + GetTensorShape(filter).DimensionsCount() - 1); + out_depth = GetTensorShape(output).Dims( + GetTensorShape(output).DimensionsCount() - 1); + batches = FlatSizeSkipDim(GetTensorShape(output), + GetTensorShape(output).DimensionsCount() - 1); + for (b = 0; b < batches; b++) { + ret = xa_nn_fully_connected_asym8xasym8_asym8( + (GetTensorData(output) + b * out_depth), + GetTensorData(filter), + (GetTensorData(input) + b * weight_depth), + GetTensorData(bias), weight_depth, out_depth, + op_params.input_offset, op_params.weights_offset, + op_params.output_multiplier, op_params.output_shift, + op_params.output_offset); + CHECK_ERR_HIFI_NNLIB_KER( + ret, "xa_nn_fully_connected_asym8xasym8_asym8 failed"); + } + for (int i = 0; i < batches * out_depth; i++) { + ACTIVATION_MIN_MAX_ASYM8(p_out[i], p_out[i], + data->output_activation_min, + data->output_activation_max) + } + break; + } + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + TF_LITE_KERNEL_LOG( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + int ret, b, weight_depth, out_depth, batches; + weight_depth = + GetTensorShape(filter).Dims(GetTensorShape(filter).DimensionsCount() - 1); + out_depth = + GetTensorShape(output).Dims(GetTensorShape(output).DimensionsCount() - 1); + batches = FlatSizeSkipDim(GetTensorShape(output), + GetTensorShape(output).DimensionsCount() - 1); + + for (b = 0; b < batches; b++) { + ret = xa_nn_fully_connected_f32( + (GetTensorData(output) + b * out_depth), + GetTensorData(filter), + (GetTensorData(input) + b * weight_depth), + GetTensorData(bias), weight_depth, out_depth); + CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_fully_connected_f32 failed."); + } + float* p_out = GetTensorData(output); + for (int i = 0; i < batches * out_depth; i++) { + ACTIVATION_MIN_MAX(float, p_out[i], p_out[i], output_activation_min, + output_activation_max) + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + + switch (filter->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, bias, + output); + case kTfLiteInt8: + return EvalQuantizedInt8(context, node, params, data, input, filter, bias, + output); + + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, filter, bias, + output); + + default: + TF_LITE_KERNEL_LOG(context, "Type %d not currently supported.", + filter->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED() { + static TfLiteRegistration r = {}; + r.init = fully_connected::Init; + r.free = fully_connected::Free; + r.prepare = fully_connected::Prepare; + r.invoke = fully_connected::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc new file mode 100644 index 00000000000..80a4f922409 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc @@ -0,0 +1,125 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/logistic.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace activations { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (input->type == kTfLiteFloat32) { + switch (output->type) { + case kTfLiteFloat32: { + int err; + const float* inp_data_ptr; + float* out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_vec_sigmoid_f32_f32(out_data_ptr, inp_data_ptr, flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_sigmoid_f32_f32 failed"); + return kTfLiteOk; + } + default: + TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + } else if (input->type == kTfLiteInt8) { + switch (output->type) { + case kTfLiteInt8: { + reference_ops::Logistic( + GetTensorShape(input), GetTensorData(input), + input->params.scale, input->params.zero_point, + GetTensorShape(output), GetTensorData(output), + output->params.scale, output->params.zero_point); + return kTfLiteOk; + } + default: + TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + } else { + // (b/141211002): Also support other data types once we have supported + // temporary tensors in TFLM. + TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace activations + +TfLiteRegistration* Register_LOGISTIC() { + static TfLiteRegistration r = {}; + r.prepare = activations::Prepare; + r.invoke = activations::Eval; + return &r; +} +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc new file mode 100755 index 00000000000..53d7d5f2031 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc @@ -0,0 +1,580 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/pooling.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace pooling { + +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +struct OpData { + TfLitePaddingValues padding; +}; + +TfLiteStatus CalculateOpData(const TfLiteContext* context, + const TfLitePoolParams* params, + const TfLiteTensor* input, + const TfLiteTensor* output, OpData* data) { + // input: batch, height, width, channel + int height = SizeOfDimension(input, 1); + int width = SizeOfDimension(input, 2); + + int out_height, out_width; + + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + /*dilation_rate_height=*/1, + /*dilation_rate_width=*/1, height, width, params->filter_height, + params->filter_width, params->padding, &out_height, &out_width); + + return kTfLiteOk; +} + +TfLiteStatus AverageEvalFloat(TfLiteContext* context, const TfLiteNode* node, + const TfLitePoolParams* params, + const OpData* data, const TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const float* inp_data_ptr; + float* out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_F32, out_precision = PREC_F32; + void* p_scratch; + int err, required_scratch = 0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void*)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_avgpool_getsize( + depth, inp_precision, out_precision, input_height, input_width, + kernel_height, kernel_width, + stride_width, // x_stride, + stride_height, // y_stride, + pad_width, // x_padding, + pad_height, // y_padding, + output_height, output_width, inp_data_format, out_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, + "AveragepoolFloat: xa_nn_avgpool_getsize failed"); + return kTfLiteError; + } + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "AveragepoolFloat: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_avgpool_f32( + &out_data_ptr[output_height * output_width * depth * batch], + &inp_data_ptr[output_height * output_width * depth * batch], + input_height, input_width, depth, kernel_height, kernel_width, + stride_width, stride_height, pad_width, pad_height, output_height, + output_width, inp_data_format, out_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolFloat: xa_nn_avgpool_f32 failed"); + } + + out_length = batches * output_height * output_width * depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + // pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min, + activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length) { + err = xa_nn_vec_activation_min_max_f32_f32( + out_data_ptr, out_data_ptr, activation_min, activation_max, out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "AveragepoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + return kTfLiteOk; +} + +TfLiteStatus AverageEvalQuantized(TfLiteContext* context, + const TfLiteNode* node, + const TfLitePoolParams* params, + const OpData* data, const TfLiteTensor* input, + TfLiteTensor* output) { + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + + if (input->type == kTfLiteUInt8) { + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const uint8* inp_data_ptr; + uint8* out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8; + void* p_scratch; + int err, required_scratch = 0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void*)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_avgpool_getsize( + depth, inp_precision, out_precision, input_height, input_width, + kernel_height, kernel_width, + stride_width, // x_stride, + stride_height, // y_stride, + pad_width, // x_padding, + pad_height, // y_padding, + output_height, output_width, inp_data_format, out_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, + "AveragepoolAsym8: xa_nn_avgpool_getsize failed"); + return kTfLiteError; + } + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, + "AveragepoolAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_avgpool_asym8( + &out_data_ptr[output_height * output_width * depth * batch], + &inp_data_ptr[output_height * output_width * depth * batch], + input_height, input_width, depth, kernel_height, kernel_width, + stride_width, stride_height, pad_width, pad_height, output_height, + output_width, inp_data_format, out_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, + "AveragepoolAsym8: xa_nn_avgpool_asym8 failed"); + } + + out_length = batches * output_height * output_width * depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + // pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min, + activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length > 0) { + err = xa_nn_vec_activation_min_max_asym8_asym8( + out_data_ptr, out_data_ptr, activation_min, activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, + "AveragepoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed"); + } + } else { + PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + reference_integer_ops::AveragePool( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +TfLiteStatus MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const float* inp_data_ptr; + float* out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_F32, out_precision = PREC_F32; + void* p_scratch; + int err, required_scratch = 0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void*)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_maxpool_getsize( + depth, inp_precision, out_precision, input_height, input_width, + kernel_height, kernel_width, + stride_width, // x_stride, + stride_height, // y_stride, + pad_width, // x_padding, + pad_height, // y_padding, + output_height, output_width, inp_data_format, out_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: xa_nn_maxpool_getsize failed"); + return kTfLiteError; + } + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_maxpool_f32( + &out_data_ptr[output_height * output_width * depth * batch], + &inp_data_ptr[output_height * output_width * depth * batch], + input_height, input_width, depth, kernel_height, kernel_width, + stride_width, stride_height, pad_width, pad_height, output_height, + output_width, inp_data_format, out_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolFloat: xa_nn_maxpool_f32 failed"); + } + + out_length = batches * output_height * output_width * depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + // pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min, + activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length > 0) { + err = xa_nn_vec_activation_min_max_f32_f32( + out_data_ptr, out_data_ptr, activation_min, activation_max, out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "MaxpoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + return kTfLiteOk; +} + +TfLiteStatus MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + + if (input->type == kTfLiteUInt8) { + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const uint8* inp_data_ptr; + uint8* out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8; + void* p_scratch; + int err, required_scratch = 0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void*)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_maxpool_getsize( + depth, inp_precision, out_precision, input_height, input_width, + kernel_height, kernel_width, + stride_width, // x_stride, + stride_height, // y_stride, + pad_width, // x_padding, + pad_height, // y_padding, + output_height, output_width, inp_data_format, out_data_format); + + if (required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: xa_nn_maxpool_getsize failed"); + return kTfLiteError; + } + + if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_maxpool_asym8( + &out_data_ptr[output_height * output_width * depth * batch], + &inp_data_ptr[output_height * output_width * depth * batch], + input_height, input_width, depth, kernel_height, kernel_width, + stride_width, stride_height, pad_width, pad_height, output_height, + output_width, inp_data_format, out_data_format, p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolAsym8: xa_nn_maxpool_asym8 failed"); + } + + out_length = batches * output_height * output_width * depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + // pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for (int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min, + activation_max) + } + + out_length = out_length - pre_loop_count; + + if (out_length > 0) { + err = xa_nn_vec_activation_min_max_asym8_asym8( + out_data_ptr, out_data_ptr, activation_min, activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER( + err, "MaxpoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed"); + } + } else { + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + reference_integer_ops::MaxPool( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + // Inputs and outputs share the same type, guarenteed by the converter. + switch (input->type) { + case kTfLiteFloat32: + AverageEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + AverageEvalQuantized(context, node, params, &data, input, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + switch (input->type) { + case kTfLiteFloat32: + MaxEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + MaxEvalQuantized(context, node, params, &data, input, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + static TfLiteRegistration r = {}; + r.init = pooling::Init; + r.free = pooling::Free; + r.prepare = pooling::Prepare; + r.invoke = pooling::AverageEval; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + static TfLiteRegistration r = {}; + r.init = pooling::Init; + r.free = pooling::Free; + r.prepare = pooling::Prepare; + r.invoke = pooling::MaxEval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc new file mode 100755 index 00000000000..3a6a0957785 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc @@ -0,0 +1,230 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/softmax.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "xtensa_tf_micro_common.h" +namespace tflite { +namespace ops { +namespace micro { +namespace activations { +namespace { + +TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + SoftmaxParams* op_data) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); + // NOTE: Current int16 softmax output does not require symmetric scaling + // - so no need to verify scale here. + } else { + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); + TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); + } + } + + static const int kScaledDiffIntegerBits = 5; + + int input_left_shift; + tflite::PreprocessSoftmaxScaling( + static_cast(params->beta), + static_cast(input->params.scale), kScaledDiffIntegerBits, + &op_data->input_multiplier, &input_left_shift); + op_data->input_left_shift = input_left_shift; + op_data->diff_min = + -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + op_data->input_left_shift); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); + op_data->beta = static_cast(params->beta); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + + return kTfLiteOk; +} + +// Takes a tensor and performs softmax along the last dimension. +TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input, + TfLiteTensor* output, const SoftmaxParams& op_data) { + const RuntimeShape& input_shape = GetTensorShape(input); + const float* input_data = GetTensorData(input); + const RuntimeShape& output_shape = GetTensorShape(output); + float* output_data = GetTensorData(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + float* p_scratch = (float*)xtensa_nnlib_scratch_buf; + + if (depth * sizeof(float) > XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory"); + return kTfLiteError; + } + + for (int i = 0; i < outer_size; ++i) { + for (int c = 0; c < depth; ++c) { + p_scratch[c] = + input_data[i * depth + c] * static_cast(op_data.beta); + } + + int err = + xa_nn_vec_softmax_f32_f32(&output_data[i * depth], p_scratch, depth); + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_f32_f32 failed"); + } + return kTfLiteOk; +} + +TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input, + TfLiteTensor* output, + const SoftmaxParams& op_data) { + if (input->type == kTfLiteUInt8) { + const RuntimeShape& input_shape = GetTensorShape(input); + const uint8_t* input_data = GetTensorData(input); + const RuntimeShape& output_shape = GetTensorShape(output); + uint8_t* output_data = GetTensorData(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + void* p_scratch = (void*)xtensa_nnlib_scratch_buf; + + if (get_softmax_scratch_size(PREC_ASYM8, PREC_ASYM8, depth) > + XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory"); + return kTfLiteError; + } + + for (int i = 0; i < outer_size; ++i) { + int err = xa_nn_vec_softmax_asym8_asym8( + &output_data[i * depth], &input_data[i * depth], op_data.diff_min, + op_data.input_left_shift, op_data.input_multiplier, depth, p_scratch); + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8_asym8 failed"); + } + } else { + if (output->type == kTfLiteInt16) { + tflite::reference_ops::Softmax( + op_data, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } else { + tflite::reference_ops::Softmax( + op_data, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } + } + return kTfLiteOk; +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = static_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + SoftmaxParams op_data; + TF_LITE_ENSURE_STATUS( + CalculateSoftmaxParams(context, input, output, params, &op_data)); + + switch (input->type) { + case kTfLiteFloat32: { + return SoftmaxFloat(context, input, output, op_data); + } + case kTfLiteInt8: + case kTfLiteUInt8: { + return SoftmaxQuantized(context, input, output, op_data); + } + default: + TF_LITE_KERNEL_LOG( + context, + "Only float32, uint8_t and int8_t input supported currently, got %d.", + input->type); + return kTfLiteError; + } +} +} // namespace activations + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, + activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval, + nullptr, + 0, + nullptr, + 0}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc new file mode 100644 index 00000000000..e0fa6db5b0b --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc @@ -0,0 +1,579 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/activation_utils.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace svdf { +namespace { + +// These constants represent constants specific to the hotword "OK G" model. +// They exist until (b/132070898) is fixed. +constexpr int kScratchTensorMaxSize = 64; + +struct OpData { + int32 effective_scale_1_a; + int32 effective_scale_2_a; + // b versions of each scale are kept at int since the numbers are just the + // shift value - typically between [-32, 32]. + int effective_scale_1_b; + int effective_scale_2_b; +}; + +/** + * This version of SVDF is specific to TFLite Micro. It contains the following + * differences between the TFLite version: + * + * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time + * for the Micro interpreter. + * 2.) Output dimensions - the TFLite version determines output size and runtime + * and resizes the output tensor. Micro runtime does not support tensor + * resizing. + */ + +static inline TfLiteStatus ApplyTimeWeightsBiasAndActivation( + TfLiteContext* context, int batch_size, int memory_size, int num_filters, + int num_units, int rank, const TfLiteTensor* weights_time, + const TfLiteTensor* bias, TfLiteFusedActivation activation, + TfLiteTensor* activation_state, TfLiteTensor* scratch, + TfLiteTensor* output) { + float* scratch_bias = GetTensorData(scratch); + if (bias) { + const float* bias_data = GetTensorData(bias); + for (int j = 0; j < num_units; ++j) { + scratch_bias[j] = *bias_data++; + } + } else { + for (int j = 0; j < num_units; ++j) { + scratch_bias[j] = 0.0f; + } + } + int err = 0; + for (int b = 0; b < batch_size; ++b) { + const float* weights_time_vec = GetTensorData(weights_time); + const float* mat_ptr = + GetTensorData(activation_state) + b * memory_size * num_filters; + float* output_ptr_batch = GetTensorData(output) + b * num_units; + for (int j = 0; j < num_units; j++) { + err = xa_nn_matXvec_f32xf32_f32( + output_ptr_batch, mat_ptr, NULL, weights_time_vec, NULL, scratch_bias, + 1, memory_size * rank, 0, memory_size * rank, 0); + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed"); + output_ptr_batch++; + mat_ptr += memory_size * rank; + weights_time_vec += memory_size * rank; + } + } + + // Apply activation. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = GetTensorData(output) + b * num_units; + for (int i = 0; i < num_units; ++i) { + *output_ptr_batch = ActivationValFloat(activation, *output_ptr_batch); + ++output_ptr_batch; + } + } + + // Left shift the activation_state to make room for next cycle's activation. + // (alanchiao): explore collapsing this into a single loop. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = + GetTensorData(activation_state) + b * memory_size * num_filters; + for (int f = 0; f < num_filters; ++f) { + // Shift the vector left: + float* batch_ptr = state_ptr_batch; + float* batch_start = state_ptr_batch + 1; + float* batch_end = state_ptr_batch + memory_size; + while (batch_start != batch_end) { + *batch_ptr++ = *batch_start++; + } + state_ptr_batch[memory_size - 1] = 0.0f; + state_ptr_batch += memory_size; + } + } + return kTfLiteOk; +} + +inline TfLiteStatus EvalFloatSVDF( + TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, + const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* activation_state, + TfLiteTensor* output) { + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Clear the activation (activation_state's leftmost column). + // (ghodrat): Add a test which initialize activation_state with invalid + // values in leftmost column and make sure it passes. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = + GetTensorData(activation_state) + b * memory_size * num_filters; + } + + // Compute conv1d(inputs, weights_feature). + // The activation_state's rightmost column is used to save current cycle + // activation. This is achieved by starting at + // GetTensorData(activation_state)[memory_size - 1] and having the + // stride equal to memory_size. + + const float* matrix = GetTensorData(weights_feature); + const float* vector = GetTensorData(input); + float* out_scratch = GetTensorData(scratch); + /* NNLib matXvec needs a bias buffer, so using output buffer to + avoid need for extra memory, output buffer size is batch * num_units, + batch is at least 1 so we use size num_units of it */ + float* bias_scratch = GetTensorData(output); + float* result = &GetTensorData(activation_state)[memory_size - 1]; + float* result_in_batch = result; + + for (int i = 0; i < num_units; i++) bias_scratch[i] = 0.0f; + + int err = 0; + for (int i = 0; i < batch_size; i++) { + /* We are using output buffer for bias (it is needed by NNLib kernel, + so only num_units size is guaranteed, so introduced rank loop and + calling matXvec for num_units rows */ + for (int j = 0; j < rank; j++) { + err = xa_nn_matXvec_f32xf32_f32( + &out_scratch[j * num_units], &matrix[j * input_size * num_units], + NULL, &vector[i * input_size], NULL, bias_scratch, num_units, + input_size, 0, input_size, 0); + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed"); + } + for (int j = 0; j < num_filters; ++j) { + *result_in_batch = out_scratch[j]; + result_in_batch += memory_size; + } + } + + return ApplyTimeWeightsBiasAndActivation( + context, batch_size, memory_size, num_filters, num_units, rank, + weights_time, bias, params->activation, activation_state, scratch, + output); +} + +void EvalIntegerSVDF( + TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor, + const TfLiteTensor* weights_feature_tensor, + const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor, + const TfLiteSVDFParams* params, TfLiteTensor* activation_state_tensor, + TfLiteTensor* output_tensor, int32_t scale_1_a, int scale_1_b, + int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) { + const int n_rank = params->rank; + const int n_batch = input_tensor->dims->data[0]; + const int n_input = input_tensor->dims->data[1]; + const int n_filter = weights_feature_tensor->dims->data[0]; + const int n_unit = n_filter / n_rank; + const int n_memory = weights_time_tensor->dims->data[1]; + + // (b/132070898): Move these temp variables to the new scratch buffer API + // when ready. + int32_t scratch_tensor[kScratchTensorMaxSize]; + int32_t scratch_output_tensor[kScratchTensorMaxSize]; + + // Rewrite last bit of state. + { + for (int b = 0; b < n_batch; ++b) { + int16_t* state_ptr_batch = + GetTensorData(activation_state_tensor) + + b * n_memory * n_filter; + for (int c = 0; c < n_filter; ++c) { + int16_t* state_ptr = state_ptr_batch + c * n_memory; + state_ptr[n_memory - 1] = 0; + } + } + } + + // Feature matmul. + { + int16_t* state = GetTensorData(activation_state_tensor); + const int8_t* input = GetTensorData(input_tensor); + const int8_t* weight_feature = + GetTensorData(weights_feature_tensor); + const int32_t output_max = std::numeric_limits::max(); + const int32_t output_min = std::numeric_limits::min(); + int16_t* result_in_batch = state + (n_memory - 1); + for (int b = 0; b < n_batch; b++) { + const int8_t* matrix_ptr = weight_feature; + for (int r = 0; r < n_filter; r++) { + int32_t dot_prod = 0; + const int8_t* vector_in_batch = input + b * n_input; + for (int c = 0; c < n_input; c++) { + dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp); + } + dot_prod = + MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b); + dot_prod = std::min(std::max(output_min, dot_prod), output_max); + *result_in_batch = dot_prod; + result_in_batch += n_memory; + } + } + } + + // Time. + { + for (int b = 0; b < n_batch; ++b) { + int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; + + // Perform batched vector dot product: + const int16_t* vector1_ptr = GetTensorData(weights_time_tensor); + const int16_t* vector2_ptr = + GetTensorData(activation_state_tensor) + + b * n_memory * n_filter; + + for (int i = 0; i < n_filter; i++) { + *scratch_ptr_batch = 0; + for (int j = 0; j < n_memory; j++) { + *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++; + } + scratch_ptr_batch++; + } + } + } + + // Reduce, add bias, rescale, activation. + { + // Add bias. + if (bias_tensor) { + // Vector batch assign: + const int32_t* bias_data = GetTensorData(bias_tensor); + for (int i = 0; i < n_batch; ++i) { + int32_t* output_ptr = scratch_output_tensor + i * n_unit; + const int32_t* bias_ptr = bias_data; + for (int j = 0; j < n_unit; ++j) { + *output_ptr++ = *bias_ptr++; + } + } + } else { + int32_t* output_ptr = scratch_output_tensor; + for (int i = 0; i < n_batch * n_unit; ++i) { + *output_ptr++ = 0; + } + } + + // Reduce. + for (int b = 0; b < n_batch; ++b) { + int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit; + int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; + + // Reduction sum vector + for (int i = 0; i < n_unit; ++i) { + for (int j = 0; j < n_rank; ++j) { + output_temp_ptr[i] += *scratch_ptr_batch++; + } + } + } + + // Rescale. + const int32_t output_max = std::numeric_limits::max(); + const int32_t output_min = std::numeric_limits::min(); + for (int i = 0; i < n_batch * n_unit; ++i) { + int32_t x1 = scratch_output_tensor[i]; + int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b); + int32_t x3 = x2 + output_zp; + int32_t x4 = std::min(std::max(output_min, x3), output_max); + GetTensorData(output_tensor)[i] = static_cast(x4); + } + } + + // Shift state. + { + for (int b = 0; b < n_batch; ++b) { + int16_t* state_ptr_batch = + GetTensorData(activation_state_tensor) + + b * n_memory * n_filter; + for (int f = 0; f < n_filter; ++f) { + // Shift the vector left: + int16_t* batch_ptr = state_ptr_batch; + int16_t* batch_start = state_ptr_batch + 1; + int16_t* batch_end = state_ptr_batch + n_memory; + while (batch_start != batch_end) { + *batch_ptr++ = *batch_start++; + } + state_ptr_batch[n_memory - 1] = 0; + state_ptr_batch += n_memory; + } + } + } +} + +} // namespace + +// Input tensors. +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +// This is a variable tensor, and will be modified by this op. +constexpr int kInputActivationStateTensor = 4; + +// Output tensor. +constexpr int kOutputTensor = 0; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast(node->builtin_data); + + // Validate Tensor Inputs (dtype depends on quantization): + // [0] = Input, {2, batch_size, input_size} + // [1] = Weights Feature, {2, num_filters, input_size} + // [2] = Weights Time, {2, num_filters, memory_size} + // [3] = Bias (optional), {1, num_units} + // [4] = Activation State (variable), + // {2, batch_size, memory_size * num_filters} + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* activation_state = + GetInput(context, node, kInputActivationStateTensor); + + // Define input constants based on input tensor definition above: + const int rank = params->rank; + const int input_size = input->dims->data[1]; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ENSURE_EQ(context, num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + const bool is_full_integer = input->type == kTfLiteInt8; + + // Validate Input Tensor: + TF_LITE_ENSURE(context, + input->type == kTfLiteFloat32 || input->type == kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); + + // Validate Tensor Output: + // [0] = float/int8, {2, batch_size, num_units} + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2); + TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units); + + // Validate Weights Feature Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2); + TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size); + + // Validate Weights Time Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2); + TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters); + TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size); + + // Validate Optional Bias Input Tensor: + if (bias) { + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units); + } + + // Validate Activation State Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); + TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1], + memory_size * num_filters); + + if (is_full_integer) { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + + TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16); + + if (bias) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + } + + TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); + + // Validate Scratch Tensors: + // [0] = (shared - see float block below for usage) + // [1] = Output Temp, int8_t, {2, num_units, batch_size} + // (b/132070898): Scratch values are used as stack variables in + // EvalIntegerSVDF(). + + // Validate output tensor: + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8); + } else { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 6); + + // Validate Input Tensor dtypes: + TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32); + + if (bias) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32); + } + + // Validate shared Scratch Tensor: + // [0] = Holds dot-product of time-forward calculations in + // ApplyTimeWeightsBiasAndActivation(): + // float/int32, {2, batch_size, num_filters} + // (b/132070898): Use input tensor as variable until scratch tensor + // allocation has been implemented (b/132070898) TfLiteTensor* + // scratch_tensor = GetTemporary(context, node, 0); + TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]]; + TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32); + + TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2); + TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[1], num_filters); + + // Full-float SVDF only uses the one shared scratch tensor (see above for + // usage). + // (b/132070898): Use input tensor as variable until scratch tensor + // allocation has been implemented. + // TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* activation_state = + GetVariableInput(context, node, kInputActivationStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const bool is_full_integer = input->type == kTfLiteInt8; + + switch (weights_feature->type) { + case kTfLiteFloat32: { + // (b/132070898): Use input tensor as variable until scratch tensor + // allocation has been implemented. TfLiteTensor* scratch = + // GetTemporary(context, node, /*index=*/0); + TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]]; + return EvalFloatSVDF(context, node, input, weights_feature, weights_time, + bias, params, scratch, activation_state, output); + break; + } + + case kTfLiteInt8: { + if (is_full_integer) { + // (b/132070898): Store these values in ::Prepare() instead of + // ::Eval(): + // Calculate effective scales. + OpData op_data; + auto* input_params = reinterpret_cast( + input->quantization.params); + auto* weights_feature_params = + reinterpret_cast( + weights_feature->quantization.params); + auto* state_params = reinterpret_cast( + activation_state->quantization.params); + auto* weight_time_params = reinterpret_cast( + weights_time->quantization.params); + auto* output_params = reinterpret_cast( + output->quantization.params); + const double effective_scale_1 = + static_cast(input_params->scale->data[0] * + weights_feature_params->scale->data[0] / + state_params->scale->data[0]); + const double effective_scale_2 = static_cast( + state_params->scale->data[0] * weight_time_params->scale->data[0] / + output_params->scale->data[0]); + QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a, + &op_data.effective_scale_1_b); + QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a, + &op_data.effective_scale_2_b); + + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); + EvalIntegerSVDF( + context, node, input, weights_feature, weights_time, bias, params, + activation_state, output, op_data.effective_scale_1_a, + op_data.effective_scale_1_b, op_data.effective_scale_2_a, + op_data.effective_scale_2_b, input->params.zero_point, + output->params.zero_point); + return kTfLiteOk; + } + break; + } + + default: + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", + TfLiteTypeGetName(weights_feature->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace svdf + +TfLiteRegistration* Register_SVDF() { + static TfLiteRegistration r = {}; + r.init = svdf::Init; + r.free = svdf::Free; + r.prepare = svdf::Prepare; + r.invoke = svdf::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h new file mode 100755 index 00000000000..cf741288d84 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h @@ -0,0 +1,80 @@ +/****************************************************************************** + * Copyright (C) 2019 Cadence Design Systems, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to use this Software with Cadence processor cores only and + * not with any other processors and platforms, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef __XTENSA_TF_MICRO_COMMON__ +#define __XTENSA_TF_MICRO_COMMON__ + +#include "xa_nnlib_api.h" +#include "xa_nnlib_standards.h" + +#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \ + if (ret != 0) { \ + TF_LITE_KERNEL_LOG(context, err_msg); \ + return kTfLiteError; \ + } + +#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE +#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024) +#endif + +#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \ + uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE]; + +#define MIN(a, b) (a) < (b) ? (a) : (b); +#define MAX(a, b) (a) > (b) ? (a) : (b); + +#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \ + { \ + data_type temp = MAX(inp, min); \ + out = MIN(temp, max); \ + } + +#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \ + { \ + float temp = MAX(inp, min); \ + out = MIN(temp, max); \ + } + +#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \ + { \ + int32_t temp = MAX((int32_t)inp, min); \ + out = (uint8_t)MIN(temp, max); \ + } + +#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1))) +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +#endif /* __XTENSA_TF_MICRO_COMMON__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc index 78468984961..40102503f97 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc @@ -126,14 +126,15 @@ void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; node->user_data = op_data; - op_data->scale_multiplier = - xtensa::hifimini::CreateQConstantForInt24(0, 1.f / output->params.scale); + op_data->scale_multiplier = xtensa::hifimini::CreateQConstantForInt24( + 0, input->params.scale / output->params.scale); return kTfLiteOk; } @@ -146,7 +147,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::QuantizationParams op_params; op_params.zero_point = output->params.zero_point; - op_params.scale = static_cast(output->params.scale); if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) { TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index 4336dccbb46..58159b1eef4 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -117,21 +117,14 @@ inline void Softmax(const SoftmaxParams& params, namespace activations { namespace { -struct OpData { - int32_t input_multiplier = 0; - int input_left_shift = 0; - int32_t input_range_radius = 0; - int diff_min = 0; -}; - // This size will work for both the hotword (1) and ambient music (0): -static OpData kStaticOpData; +static SoftmaxParams kStaticOpData; TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, const TfLiteSoftmaxParams* params, - OpData* data) { + SoftmaxParams* op_data) { if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -148,12 +141,14 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, static const int kScaledDiffIntegerBits = 5; + int input_left_shift; tflite::PreprocessSoftmaxScaling( - static_cast(params->beta), - static_cast(input->params.scale), kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift); - data->diff_min = -1.0 * tflite::CalculateInputRadius( - kScaledDiffIntegerBits, data->input_left_shift); + params->beta, input->params.scale, kScaledDiffIntegerBits, + &op_data->input_multiplier, &input_left_shift); + op_data->input_left_shift = input_left_shift; + op_data->diff_min = + -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + op_data->input_left_shift); } return kTfLiteOk; } @@ -161,12 +156,7 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, } // namespace void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; - + const SoftmaxParams& op_params) { if (output->type == kTfLiteInt16) { xtensa::hifimini::Softmax( op_params, GetTensorShape(input), GetTensorData(input), @@ -186,7 +176,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); + auto* params = static_cast(node->builtin_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -194,27 +184,26 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE(context, NumDimensions(input) >= 1); - // TODO(b/132070898): Use statically slotted OpData structures until a + // TODO(b/132070898): Use statically slotted SoftmaxParams structures until a // scratch memory API is ready. - OpData* op_data = &kStaticOpData; - node->user_data = op_data; + SoftmaxParams* op_params = &kStaticOpData; + node->user_data = op_params; TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, op_data)); + CalculateSoftmaxOpData(context, input, output, params, op_params)); return kTfLiteOk; } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - auto* op_data = reinterpret_cast(node->user_data); + auto* op_params = static_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteInt8: { - SoftmaxQuantized(input, output, params, op_data); + SoftmaxQuantized(input, output, *op_params); return kTfLiteOk; } default: diff --git a/tensorflow/lite/micro/memory_planner/BUILD b/tensorflow/lite/micro/memory_planner/BUILD index 33c9869afec..9e53fb1f874 100644 --- a/tensorflow/lite/micro/memory_planner/BUILD +++ b/tensorflow/lite/micro/memory_planner/BUILD @@ -5,6 +5,7 @@ load( load( "//tensorflow/lite/micro:build_def.bzl", "cc_library", + "micro_copts", ) package( @@ -18,11 +19,7 @@ cc_library( "memory_planner.h", ], build_for_embedded = True, - copts = [ - "-Werror", - "-Wdouble-promotion", - "-Wsign-compare", - ], + copts = micro_copts(), deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", @@ -37,11 +34,7 @@ cc_library( hdrs = [ "linear_memory_planner.h", ], - copts = [ - "-Werror", - "-Wdouble-promotion", - "-Wsign-compare", - ], + copts = micro_copts(), deps = [ ":memory_planner", "//tensorflow/lite/c:common", @@ -58,11 +51,7 @@ cc_library( "greedy_memory_planner.h", ], build_for_embedded = True, - copts = [ - "-Werror", - "-Wdouble-promotion", - "-Wsign-compare", - ], + copts = micro_copts(), deps = [ ":memory_planner", "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc index f96b249f1a8..faea73e9169 100644 --- a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc +++ b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc @@ -331,7 +331,7 @@ void GreedyMemoryPlanner::PrintMemoryPlan(ErrorReporter* error_reporter) { } } line[kLineWidth] = 0; - TF_LITE_REPORT_ERROR(error_reporter, "%s", line); + TF_LITE_REPORT_ERROR(error_reporter, "%s", (const char*)line); } } diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 42f25f0e8b0..1d8c0745e4a 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -1,3 +1,8 @@ +load( + "//tensorflow/lite/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -23,6 +28,16 @@ cc_library( ], ) +tflite_micro_cc_test( + name = "util_test", + srcs = [ + "util_test.cc", + ], + deps = [ + ":micro_test", + ], +) + cc_library( name = "micro_benchmark", hdrs = [ diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h index f46bb0f4cfc..67fe86b0068 100644 --- a/tensorflow/lite/micro/testing/micro_test.h +++ b/tensorflow/lite/micro/testing/micro_test.h @@ -109,9 +109,11 @@ extern tflite::ErrorReporter* reporter; #define TF_LITE_MICRO_EXPECT_EQ(x, y) \ do { \ - if ((x) != (y)) { \ + auto vx = x; \ + auto vy = y; \ + if ((vx) != (vy)) { \ micro_test::reporter->Report(#x " == " #y " failed at %s:%d (%d vs %d)", \ - __FILE__, __LINE__, (x), (y)); \ + __FILE__, __LINE__, (vx), (vy)); \ micro_test::did_test_fail = true; \ } \ } while (false) @@ -142,15 +144,17 @@ extern tflite::ErrorReporter* reporter; } \ } while (false) -#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ - do { \ - auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \ - if (delta > epsilon) { \ - micro_test::reporter->Report( \ - #x " (%f) near " #y " (%f) failed at %s:%d", static_cast(x), \ - static_cast(y), __FILE__, __LINE__); \ - micro_test::did_test_fail = true; \ - } \ +#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ + do { \ + auto vx = (x); \ + auto vy = (y); \ + auto delta = ((vx) > (vy)) ? ((vx) - (vy)) : ((vy) - (vx)); \ + if (delta > epsilon) { \ + micro_test::reporter->Report( \ + #x " (%f) near " #y " (%f) failed at %s:%d", static_cast(vx), \ + static_cast(vy), __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ } while (false) #define TF_LITE_MICRO_EXPECT_GT(x, y) \ diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh new file mode 100755 index 00000000000..50415e7cf11 --- /dev/null +++ b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh @@ -0,0 +1,59 @@ +#!/bin/bash -e +# ============================================================================== +# Copyright (C) 2019 Cadence Design Systems, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to use this Software with Cadence processor cores only and +# not with any other processors and platforms, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# ============================================================================== + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Tests an Xtensa binary by parsing the log output. +# +# First argument is the binary location. +# Second argument is a regular expression that's required to be in the output +# logs for the test to pass. + +declare -r ROOT_DIR=`pwd` +declare -r TEST_TMPDIR=/tmp/test_xtensa_hifi_binary/ +declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1 +declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt +mkdir -p ${MICRO_LOG_PATH} + +xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME} + +if grep -q "$2" ${MICRO_LOG_FILENAME} +then + echo "$1: PASS" + exit 0 +else + echo "$1: FAIL - '$2' not found in logs." + exit 1 +fi diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h b/tensorflow/lite/micro/testing/util_test.cc similarity index 52% rename from tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h rename to tensorflow/lite/micro/testing/util_test.cc index ba63cdfe90b..f4eb28e121a 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h +++ b/tensorflow/lite/micro/testing/util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_ -#define TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_ +#include "tensorflow/lite/micro/testing/micro_test.h" +#include "tensorflow/lite/micro/testing/test_utils.h" -#include "tensorflow/lite/c/common.h" +TF_LITE_MICRO_TESTS_BEGIN -// todo: remove this function once context->AllocateTemporaryTensor() is -// implemented. -TfLiteStatus get_cmsis_scratch_buffer(TfLiteContext* context, int16_t** buf, - int32_t buf_size); +TF_LITE_MICRO_TEST(ArgumentsExecutedOnlyOnce) { + float count = 0.; + // Make sure either argument is executed once after macro expansion. + TF_LITE_MICRO_EXPECT_NEAR(0, count++, 0.1); + TF_LITE_MICRO_EXPECT_NEAR(1, count++, 0.1); + TF_LITE_MICRO_EXPECT_NEAR(count++, 2, 0.1); + TF_LITE_MICRO_EXPECT_NEAR(count++, 3, 0.1); +} -#endif // TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_ +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index e78979032a2..4e8c2ae5758 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -2,8 +2,11 @@ ifneq (3.82,$(firstword $(sort $(MAKE_VERSION) 3.82))) $(error "Requires make version 3.82 or later (current is $(MAKE_VERSION))") endif +# root directory of tensorflow +TENSORFLOW_ROOT := MAKEFILE_DIR := tensorflow/lite/micro/tools/make + # Pull in some convenience functions. include $(MAKEFILE_DIR)/helper_functions.inc @@ -68,8 +71,10 @@ MICROLITE_LIBS := -lm CXXFLAGS := -std=c++11 -DTF_LITE_STATIC_MEMORY CCFLAGS := -std=c11 -DTF_LITE_STATIC_MEMORY ARFLAGS := -r + +# override these in the makefile.inc for specific compiler targets TARGET_TOOLCHAIN_PREFIX := -CC_PREFIX := +TARGET_TOOLCHAIN_ROOT := ifeq ($(BUILD_TYPE), debug) CXXFLAGS += -DDEBUG -g @@ -127,7 +132,7 @@ tensorflow/lite/core/api/error_reporter.h \ tensorflow/lite/core/api/flatbuffer_conversions.h \ tensorflow/lite/core/api/op_resolver.h \ tensorflow/lite/core/api/tensor_utils.h \ -tensorflow/lite/experimental/ruy/profiler/instrumentation.h \ +tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h \ tensorflow/lite/kernels/internal/common.h \ tensorflow/lite/kernels/internal/compatibility.h \ tensorflow/lite/kernels/internal/optimized/neon_check.h \ @@ -165,7 +170,7 @@ tensorflow/lite/kernels/internal/reference/sub.h \ tensorflow/lite/kernels/internal/reference/logistic.h \ tensorflow/lite/kernels/internal/reference/strided_slice.h \ tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \ -tensorflow/lite/kernels/internal/round.h \ +tensorflow/lite/kernels/internal/cppmath.h \ tensorflow/lite/kernels/internal/strided_slice_logic.h \ tensorflow/lite/kernels/internal/tensor.h \ tensorflow/lite/kernels/internal/tensor_ctypes.h \ @@ -190,8 +195,8 @@ third_party/flatbuffers/include/flatbuffers/flatbuffers.h \ third_party/flatbuffers/LICENSE.txt MAKE_PROJECT_FILES := \ - README_MAKE.md \ Makefile \ + README_MAKE.md \ .vscode/tasks.json MBED_PROJECT_FILES := \ @@ -249,9 +254,9 @@ PRJDIR := $(GENDIR)prj/ MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME) -CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}${CXX_TOOL} -CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}${CC_TOOL} -AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}${AR_TOOL} +CXX := $(TARGET_TOOLCHAIN_ROOT)${TARGET_TOOLCHAIN_PREFIX}${CXX_TOOL} +CC := $(TARGET_TOOLCHAIN_ROOT)${TARGET_TOOLCHAIN_PREFIX}${CC_TOOL} +AR := $(TARGET_TOOLCHAIN_ROOT)${TARGET_TOOLCHAIN_PREFIX}${AR_TOOL} # Load the examples. include $(MICRO_LITE_EXAMPLE_TESTS) @@ -309,7 +314,7 @@ $(BINDIR)%.test_target: $(BINDIR)%_test # snease: Add %.bin rule here since BINDIR is now defined # These are microcontroller-specific rules for converting the ELF output # of the linker into a binary image that can be loaded directly. -OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy +OBJCOPY := ${TARGET_TOOLCHAIN_ROOT}$(TARGET_TOOLCHAIN_PREFIX)objcopy $(BINDIR)%.bin: $(BINDIR)% @mkdir -p $(dir $@) $(OBJCOPY) $< $@ -O binary diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc index b4d6e505650..cfd87089a84 100644 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc @@ -21,13 +21,6 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) THIRD_PARTY_CC_HDRS += \ $(call recursive_find,$(CMSIS_PATH)/CMSIS/Core/Include,*.h) - # todo: remove the two lines below once context->AllocateTemporaryTensor() - # is implemented. - MICROLITE_CC_HDRS += \ - tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h - MICROLITE_CC_SRCS += \ - tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc - INCLUDES += -I$(CMSIS_PATH)/CMSIS/Core/Include \ -I$(CMSIS_PATH)/CMSIS/NN/Include \ -I$(CMSIS_PATH)/CMSIS/DSP/Include diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc new file mode 100644 index 00000000000..bd79d9cacca --- /dev/null +++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc @@ -0,0 +1,67 @@ +ifneq ($(filter xtensa_hifi, $(ALL_TAGS)),) + + XTENSA_PATH = $(MAKEFILE_DIR)/downloads + + ifneq (,$(filter hifi4%, $(TARGET_ARCH))) + + CCFLAGS += -DNNLIB_V2 \ + -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024 + + CXXFLAGS += -DNNLIB_V2 \ + -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024 + + MICROLITE_CC_SRCS += \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_f32_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_asym8_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_softmax_asym8_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_floor_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_asym8xasym8_asym8_circ.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_f32_circ.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_circ_buf.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/fc/hifi4/xa_nn_fully_connected.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_16x16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_inv_256_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_sigmoidf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_tanhf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_reluf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_softmaxf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_alognf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_sigmoidf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_tanhf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/expf_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/pow2f_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/inff_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/tanhf_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/nanf_tbl.c \ + + INCLUDES += -I$(XTENSA_PATH)/xa_nnlib/algo/kernels/ \ + -I$(XTENSA_PATH)/xa_nnlib/include/nnlib/ \ + -I$(XTENSA_PATH)/xa_nnlib/include/ \ + -I$(XTENSA_PATH)/xa_nnlib/algo/common/include/ \ + -I$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/include/ \ + + endif + +endif diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc index aae00bd119d..184a0293ad7 100644 --- a/tensorflow/lite/micro/tools/make/helper_functions.inc +++ b/tensorflow/lite/micro/tools/make/helper_functions.inc @@ -65,6 +65,8 @@ specialize = $(call specialize_on_tags,$(1),$(strip $(call reverse,$(ALL_TAGS))) # 6 - Linker flags required. # 7 - C++ compilation flags needed. # 8 - C compilation flags needed. +# 9 - Target Toolchian root directory +# 10 - Target Toolchain prefix # Calling eval on the output will create a _makefile target that you # can invoke to create the standalone project. define generate_project @@ -82,7 +84,9 @@ $(PRJDIR)$(3)/$(1)/%: tensorflow/lite/micro/tools/make/templates/%.tpl sed -E 's#\%\{EXECUTABLE\}\%#$(3)#g' | \ sed -E 's#\%\{LINKER_FLAGS\}\%#$(6)#g' | \ sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' | \ - sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' > $$@ + sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' | \ + sed -E 's#\%\{TARGET_TOOLCHAIN_ROOT\}\%#$(9)#g' | \ + sed -E 's#\%\{TARGET_TOOLCHAIN_PREFIX\}\%#$(10)#g' > $$@ $(PRJDIR)$(3)/$(1)/keil_project.uvprojx: tensorflow/lite/micro/tools/make/templates/keil_project.uvprojx.tpl @mkdir -p $$(dir $$@) @@ -120,6 +124,7 @@ endef # 6 - Linker flags required. # 7 - C++ compilation flags needed. # 8 - C compilation flags needed. + # Calling eval on the output will create a _makefile target that you # can invoke to create the standalone project. define generate_arc_project @@ -134,6 +139,7 @@ $(PRJDIR)$(3)/$(1)/Makefile: tensorflow/lite/micro/tools/make/templates/Makefile sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' | \ sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' > $$@ + # Special rule to copy TCF in case the local filesystem file name has been defined ifneq ($(TCF_FILE_NAME), ) $(PRJDIR)$(3)/$(1)/$(TCF_FILE_NAME): $(TCF_FILE) @@ -366,10 +372,10 @@ endef # Calling eval on the output will create targets that you can invoke to # generate the standalone project. define generate_microlite_projects -$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(MICROLITE_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES)) +$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(MICROLITE_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES),$(TARGET_TOOLCHAIN_ROOT),$(TARGET_TOOLCHAIN_PREFIX)) $(call generate_arc_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(GENERATED_PROJECT_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES)) -$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS)) -$(call generate_project,keil,$(KEIL_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS)) +$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS),$(TARGET_TOOLCHAIN_ROOT),$(TARGET_TOOLCHAIN_PREFIX)) +$(call generate_project,keil,$(KEIL_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS),$(TARGET_TOOLCHAIN_ROOT),$(TARGET_TOOLCHAIN_PREFIX)) $(call generate_arduino_project,$(ARDUINO_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS)) $(call generate_esp_project,$(ESP_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS),$(2),$(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS),$(PROJECT_INCLUDES)) endef diff --git a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc index 9494158cd50..fa20ad99125 100644 --- a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc @@ -61,8 +61,10 @@ $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST)/$(SF_BSPS_DEST): $(MAKEFILE_DIR)/downlo -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ + -fno-use-cxa-atexit \ -nostdlib \ -ggdb \ -O3 diff --git a/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc b/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc index 3c397a0ab80..0f56e5f4641 100644 --- a/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc @@ -9,7 +9,7 @@ ifneq ($(TCF_FILE), ) TARGET = $(basename $(notdir $(TCF_FILE))) else TARGET = em7d_voice_audio - TCF_FILE = em7d_voice_audio + TCF_FILE = em7d_voice_audio endif # The variable TCF_FILE_NAME stores the TCF file name (including .tcf extension), this variable is used later to add the option to the linker/compiler flags. diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 29a49288081..ac066408d9a 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -41,6 +41,7 @@ ifeq ($(TARGET), bluepill) -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ + -fno-use-cxa-atexit \ -nostdlib \ -g \ -Os diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc b/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc index 8b24f5beb92..bb6a9f3daf5 100644 --- a/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc @@ -48,8 +48,10 @@ ifeq ($(TARGET), ecm3531) -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ + -fno-use-cxa-atexit \ -nostdlib \ -g \ -Os diff --git a/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc b/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc index 7336c520b11..ddd06718bed 100644 --- a/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc @@ -14,7 +14,7 @@ ifeq ($(TARGET), riscv32_mcu) -fno-builtin-printf \ -fno-exceptions \ -DTF_LITE_MCU_DEBUG_LOG \ - -DTF_LITE_USE_GLOBAL_ROUND \ + -DTF_LITE_USE_GLOBAL_CMATH_FUNCTIONS \ -fno-unwind-tables \ -fno-builtin \ -ffunction-sections \ @@ -31,7 +31,9 @@ ifeq ($(TARGET), riscv32_mcu) -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ -fomit-frame-pointer \ + -fno-use-cxa-atexit \ -Os CXXFLAGS += $(PLATFORM_FLAGS) \ @@ -79,4 +81,4 @@ ifeq ($(TARGET), riscv32_mcu) $(BINDIR)/%.bin: $(BINDIR)/% @mkdir -p $(dir $@) $(OBJCOPY) $< $@ -O binary -endif \ No newline at end of file +endif diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc index 539f4528d06..2bad89e423e 100644 --- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc @@ -34,6 +34,7 @@ ifeq ($(TARGET), stm32f4) -fno-delete-null-pointer-checks \ -fomit-frame-pointer \ -fpermissive \ + -fno-use-cxa-atexit \ -g \ -Os CXXFLAGS += $(PLATFORM_FLAGS) @@ -76,6 +77,8 @@ ifeq ($(TARGET), stm32f4) tensorflow/lite/micro/kernels/dequantize_test.cc \ tensorflow/lite/micro/kernels/unpack_test.cc \ tensorflow/lite/micro/kernels/split_test.cc \ + tensorflow/lite/micro/kernels/conv_test.cc \ + tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ tensorflow/lite/micro/simple_tensor_allocator_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) EXCLUDED_EXAMPLE_TESTS := \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md new file mode 100644 index 00000000000..fd606a7f96b --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md @@ -0,0 +1,35 @@ +# Building TensorFlow Lite for Microcontrollers for Cadence Tensilica HiFi DSPs + +This document describes the steps to build and run the Tensorflow Lite Micro on +the Cadence HiFi DSPs. + +## Pre-requisites + +The Xtensa development tools and the target processor configurations should be +installed on the system. Please check [https://tensilicatools.com] for more +information about downloading and installing the required tools. + +The PATH variable should be set to include the /bin +directory. The XTENSA_SYSTEM and XTENSA_CORE environment variables should be set +to the required tools version and the required processor configuration. + +## Building for HiFi Processors + +To build the code using Xtensa tools for the processor configuration selected by +XTENSA_CORE , set TARGET=xtensa_hifi. Additionally TARGET_ARCH can be used to +select optimized HiFi NN kernels specific to the processor configuration. +Currently the HiFi4 NN kernels are provided which can be enabled as follows: + +make -f tensorflow/lite/micro/tools/make/Makefile test_micro_speech_test +TARGET=xtensa_hifi TARGET_ARCH=hifi4 + +Xtensa specific TF Lite Micro kernels are implemented in this folder: +tensorflow/lite/micro/kernels/xtensa_hifi/ + +A scratch memory allocation is needed for the HiFi optimized kernels. This +allocation is currently done on stack and it's size can be controlled by +defining 'XTENSA_NNLIB_MAX_SCRATCH_SIZE' approproately in the file +'tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc + +The files containing the HiFi optimized NN kernels are present in this folder: +tensorflow/lite/micro/kernels/xtensa_hifi/xa_nnlib/ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc new file mode 100644 index 00000000000..aa7d8cfb1c3 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc @@ -0,0 +1,44 @@ +# Settings for Xtensa toolchain. +# Derived from xtensa_xpg_makefile.inc +# The Xtensa environment variables should be configured externally (XTENSA_CORE, XTENSA_SYSTEM) + +ifeq ($(TARGET), xtensa_hifi) + TARGET_ARCH := hifi3_bd5 + +$(eval $(call add_third_party_download,$(XTENSA_HIFI4_URL),$(XTENSA_HIFI4_MD5),xa_nnlib,)) + + PLATFORM_ARGS = \ + -mno-mul16 \ + -mno-mul32 \ + -mno-div32 \ + -fsigned-char \ + -fno-exceptions \ + -mlongcalls \ + -INLINE:requested \ + -mcoproc \ + -fno-zero-initialized-in-bss \ + -mtext-section-literals \ + -fno-unsafe-math-optimizations \ + + TF_LITE_MICRO_FLAGS = \ + -DTF_LITE_STATIC_MEMORY\ + + TARGET_TOOLCHAIN_PREFIX := xt- + CXX_TOOL := clang++ + CC_TOOL := clang + + CXXFLAGS = -O0 $(PLATFORM_ARGS) -std=c++11 $(TF_LITE_MICRO_FLAGS) + #TODO: Use -std=c11 ? + CCFLAGS = -O3 $(PLATFORM_ARGS) $(TF_LITE_MICRO_FLAGS) + + TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh + + # These are microcontroller-specific rules for converting the ELF output + # of the linker into a binary image that can be loaded directly. + OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy + + $(BINDIR)/%.bin: $(BINDIR)/% + echo "here" + @mkdir -p $(dir $@) + $(OBJCOPY) $< $@ -O binary +endif diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc index 5836aea417d..5ed601f8dd1 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc @@ -16,7 +16,9 @@ ifeq ($(TARGET), xtensa-xpg) -ffunction-sections \ -fno-exceptions \ -fno-unwind-tables \ - -fmessage-length=0 + -fno-use-cxa-atexit \ + -fmessage-length=0 \ + -fno-threadsafe-statics TARGET_TOOLCHAIN_PREFIX := xt- CXX_TOOL := clang++ diff --git a/tensorflow/lite/micro/tools/make/templates/Makefile.tpl b/tensorflow/lite/micro/tools/make/templates/Makefile.tpl index f72658f4aa0..6078b927fa0 100644 --- a/tensorflow/lite/micro/tools/make/templates/Makefile.tpl +++ b/tensorflow/lite/micro/tools/make/templates/Makefile.tpl @@ -1,16 +1,37 @@ -RM = rm -f +TARGET_TOOLCHAIN_ROOT := %{TARGET_TOOLCHAIN_ROOT}% +TARGET_TOOLCHAIN_PREFIX := %{TARGET_TOOLCHAIN_PREFIX}% +# These are microcontroller-specific rules for converting the ELF output +# of the linker into a binary image that can be loaded directly. +CXX := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)g++' +CC := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)gcc' +AS := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)as' +AR := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)ar' +LD := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)ld' +NM := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)nm' +OBJDUMP := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)objdump' +OBJCOPY := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)objcopy' +SIZE := '$(TARGET_TOOLCHAIN_ROOT)$(TARGET_TOOLCHAIN_PREFIX)size' + +RM = rm -f +ARFLAGS := -csr SRCS := \ %{SRCS}% OBJS := \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(SRCS))) +LIBRARY_OBJS := $(filter-out tensorflow/lite/micro/examples/%, $(OBJS)) + CXXFLAGS += %{CXX_FLAGS}% CCFLAGS += %{CC_FLAGS}% LDFLAGS += %{LINKER_FLAGS}% + +# library to be generated +MICROLITE_LIB = libtensorflow-microlite.a + %.o: %.cc $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ @@ -20,8 +41,17 @@ LDFLAGS += %{LINKER_FLAGS}% %{EXECUTABLE}% : $(OBJS) $(CXX) $(CXXFLAGS) -o $@ $(OBJS) $(LDFLAGS) + +# Creates a tensorflow-litemicro.a which excludes any example code. +$(MICROLITE_LIB): tensorflow/lite/schema/schema_generated.h $(LIBRARY_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(MICROLITE_LIB) $(LIBRARY_OBJS) + all: %{EXECUTABLE}% +lib: $(MICROLITE_LIB) + clean: -$(RM) $(OBJS) -$(RM) %{EXECUTABLE}% + -$(RM) ${MICROLITE_LIB} diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index ca544d1371e..c4ff652a0ff 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -59,3 +59,7 @@ EMBARC_OSP_MD5 := "9eaf7b3a1ed05872a03da9796672a776" EMBARC_MLI_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/archive/6316034d421cbbb59756239908d7c9a99075a3bb.zip" EMBARC_MLI_MD5 := "db0910cf0e07e43f74ae7a31de485d56" + +XTENSA_HIFI4_URL :="https://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib.zip" +XTENSA_HIFI4_MD5 :="a517b653a75b96d0271e1b99ee2a8c14" + diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index 5819142ee25..84dc00f145b 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -31,6 +31,7 @@ namespace tflite { #if TFLITE_EXPERIMENTAL_RUNTIME_EAGER using InterpreterBuilder = tflrt::EagerTfLiteInterpreterBuilderAPI; +using Interpreter = tflrt::EagerInterpreter; #else using InterpreterBuilder = impl::InterpreterBuilder; #endif diff --git a/tensorflow/lite/model_xnnpack_test.cc b/tensorflow/lite/model_xnnpack_test.cc new file mode 100644 index 00000000000..9c06147f602 --- /dev/null +++ b/tensorflow/lite/model_xnnpack_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/model.h" + +#include + +#include +#include "tensorflow/lite/core/macros.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/util.h" + +namespace tflite { + +TEST(FloatModel, WithXnnpackDelegate) { + // Note: this graph will be fully delegated by the XNNPACK delegate. + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/multi_add.bin"); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, + ops::builtin::BuiltinOpResolver{})(&interpreter), + kTfLiteOk); + ASSERT_TRUE(interpreter); + + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + +#if TFLITE_HAS_ATTRIBUTE_WEAK + // As the graph is fully delegated by XNNPACK delegate, we will expect the + // following: + EXPECT_EQ(1, interpreter->execution_plan().size()); + int first_node_id = interpreter->execution_plan()[0]; + const auto& first_node_reg = + interpreter->node_and_registration(first_node_id)->second; + const std::string op_name = GetOpNameByRegistration(first_node_reg); + EXPECT_EQ("DELEGATE TfLiteXNNPackDelegate", op_name); +#endif +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD index 4cd5ab3922f..82d775dd94b 100644 --- a/tensorflow/lite/nnapi/BUILD +++ b/tensorflow/lite/nnapi/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/lite:special_rules.bzl", "if_nnapi") + package( default_visibility = [ "//visibility:public", @@ -11,51 +13,22 @@ cc_library( "NeuralNetworksShim.h", "NeuralNetworksTypes.h", ], - linkopts = select({ - "//tensorflow:emscripten": [], - "//tensorflow:ios": [], - "//tensorflow:macos": [], - "//tensorflow:windows": [], - "//conditions:default": ["-ldl"], - }), + linkopts = if_nnapi(["-ldl"]), ) cc_library( name = "nnapi_implementation", - srcs = select({ - "//tensorflow:emscripten": [ - "nnapi_implementation_disabled.cc", - ], - "//tensorflow:ios": [ - "nnapi_implementation_disabled.cc", - ], - "//tensorflow:macos": [ - "nnapi_implementation_disabled.cc", - ], - "//tensorflow:windows": [ - "nnapi_implementation_disabled.cc", - ], - "//conditions:default": [ - "nnapi_implementation.cc", - ], - }), + srcs = if_nnapi( + not_supported = ["nnapi_implementation_disabled.cc"], + supported = ["nnapi_implementation.cc"], + ), hdrs = [ "nnapi_implementation.h", ], - linkopts = select({ - "//tensorflow:emscripten": [], - "//tensorflow:ios": [], - "//tensorflow:macos": [], - "//tensorflow:windows": [], - "//conditions:default": ["-ldl"], - }) + select({ - "//tensorflow:android": [], - "//tensorflow:emscripten": [], - "//tensorflow:ios": [], - "//tensorflow:macos": [], - "//tensorflow:windows": [], - "//conditions:default": ["-lrt"], - }), + linkopts = if_nnapi(["-ldl"]) + if_nnapi( + supported = ["-lrt"], + supported_android = [], + ), deps = [ ":nnapi_lib", ], @@ -84,16 +57,8 @@ cc_test( # Cannot inject NNAPI instance on ios and windows cc_library( name = "nnapi_handler", - srcs = select({ - "//tensorflow:ios": [], - "//tensorflow:windows": [], - "//conditions:default": ["nnapi_handler.cc"], - }), - hdrs = select({ - "//tensorflow:ios": [], - "//tensorflow:windows": [], - "//conditions:default": ["nnapi_handler.h"], - }), + srcs = if_nnapi(["nnapi_handler.cc"]), + hdrs = if_nnapi(["nnapi_handler.h"]), deps = [ ":nnapi_implementation", ":nnapi_lib", diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 7248792523e..8333fa418c2 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -143,6 +143,7 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = [ + "no_rocm", "no_windows", ], deps = [ @@ -159,6 +160,7 @@ py_test( python_version = "PY3", srcs_version = "PY2AND3", tags = [ + "no_rocm", "no_windows", ], deps = [ diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index e3d7d04be14..39f303b3a68 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -43,7 +43,7 @@ if not __file__.endswith('tflite_runtime/interpreter.py'): del LazyLoader else: # This file is part of tflite_runtime package. - from tflite_runtime import interpreter_wrapper as _interpreter_wrapper + from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper def _tf_export(*x, **kwargs): del x, kwargs @@ -77,6 +77,7 @@ class Delegate(object): keys and values in the dictionary should be serializable. Consult the documentation of the specific delegate for required and legal options. (default None) + Raises: RuntimeError: This is raised if the Python implementation is not CPython. """ @@ -191,7 +192,7 @@ class Interpreter(object): model_content: Content of model. experimental_delegates: Experimental. Subject to change. List of [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) - objects returned by lite.load_delegate(). + objects returned by lite.load_delegate(). Raises: ValueError: If the interpreter was unable to create. @@ -388,14 +389,16 @@ class Interpreter(object): ] def set_tensor(self, tensor_index, value): - """Sets the value of the input tensor. Note this copies data in `value`. + """Sets the value of the input tensor. + + Note this copies data in `value`. If you want to avoid copying, you can use the `tensor()` function to get a numpy buffer pointing to the input buffer in the tflite interpreter. Args: tensor_index: Tensor index of tensor to set. This value can be gotten from - the 'index' field in get_input_details. + the 'index' field in get_input_details. value: Value of tensor to set. Raises: @@ -408,7 +411,7 @@ class Interpreter(object): Args: input_index: Tensor index of input to set. This value can be gotten from - the 'index' field in get_input_details. + the 'index' field in get_input_details. tensor_size: The tensor_shape to resize the input to. Raises: @@ -438,7 +441,7 @@ class Interpreter(object): Args: tensor_index: Tensor index of tensor to get. This value can be gotten from - the 'index' field in get_output_details. + the 'index' field in get_output_details. Returns: a numpy array. @@ -486,7 +489,7 @@ class Interpreter(object): Args: tensor_index: Tensor index of tensor to get. This value can be gotten from - the 'index' field in get_output_details. + the 'index' field in get_output_details. Returns: A function that can return a new numpy array pointing to the internal diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 151aecf02cb..ba9e6e0bd39 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -280,8 +280,8 @@ class TFLiteConverterBase(object): self.optimizations = [] self.representative_dataset = None self.experimental_new_converter = _USE_EXPERIMENTAL_NEW_CONVERTER - self.experimental_new_quantizer = False - self.experimental_calibrate_only = False + self._experimental_new_quantizer = False + self._experimental_calibrate_only = False # The 'GraphDebugInfo' contains the stack traces of all the original nodes # in the `GraphDef` to the converter. self._debug_info = None @@ -314,7 +314,7 @@ class TFLiteConverterBase(object): self.representative_dataset = RepresentativeDataset( self.representative_dataset) calibrate_quantize = _calibrator.Calibrator(result) - if self.experimental_calibrate_only: + if self._experimental_calibrate_only: return calibrate_quantize.calibrate(self.representative_dataset.input_gen) else: return calibrate_quantize.calibrate_and_quantize( @@ -370,11 +370,6 @@ class TFLiteConverterV2(TFLiteConverterBase): target ops. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. - experimental_new_quantizer: Experimental flag, subject to change. - Enables MLIR-based post-training quantization. - experimental_calibrate_only: Experimental flag, subject to change. - Calibrates the converted model with representative dataset, but not - quantize it. Example usage: ```python @@ -698,11 +693,6 @@ class TFLiteConverter(TFLiteConverterBase): the dataset to evaluate different optimizations. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. - experimental_new_quantizer: Experimental flag, subject to change. - Enables MLIR-based post-training quantization. - experimental_calibrate_only: Experimental flag, subject to change. - Calibrates the converted model with representative dataset, but not - quantize it. Example usage: ```python diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index e0595893531..508c4fc2053 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -263,11 +263,11 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): quantized_converter.representative_dataset = calibration_gen # default quantizer - quantized_converter.experimental_new_quantizer = False + quantized_converter._experimental_new_quantizer = False old_tflite = quantized_converter.convert() # new quantizer - quantized_converter.experimental_new_quantizer = True + quantized_converter._experimental_new_quantizer = True new_tflite = quantized_converter.convert() for _ in range(5): diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index 9a5d1e9aa2f..cadd5538f5b 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -114,6 +114,44 @@ PyObject* CalibrationWrapper::Prepare() { Py_RETURN_NONE; } +PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + if (!PyList_Check(input_shapes)) { + PyErr_Format(PyExc_ValueError, + "Invalid input shapes: expected shapes to be a list."); + return nullptr; + } + + const size_t inputs_size = PyList_Size(input_shapes); + if (inputs_size != interpreter_->inputs().size()) { + PyErr_Format(PyExc_ValueError, + "Invalid input shapes: expected %ld items got %ld items.", + interpreter_->inputs().size(), inputs_size); + return nullptr; + } + + for (size_t i = 0; i < inputs_size; i++) { + PyObject* shape = PyList_GetItem(input_shapes, i); + if (!shape || !PyList_Check(shape)) { + PyErr_Format(PyExc_ValueError, + "Invalid %ld input shape: expected to be a list.", i); + return nullptr; + } + std::vector dims; + for (size_t dim_index = 0; dim_index < PyList_Size(shape); ++dim_index) { + PyObject* dim = PyList_GetItem(shape, dim_index); + dims.push_back(PyLong_AsLong(dim)); + } + int input_tensor_idx = interpreter_->inputs()[i]; + if (interpreter_->ResizeInputTensor(input_tensor_idx, dims) != kTfLiteOk) { + PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i); + return nullptr; + } + } + + return Prepare(); +} + PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); if (!PyList_Check(input_value)) { diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index 449f8ee6b83..fc8d6c1c890 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -57,6 +57,7 @@ class CalibrationWrapper { ~CalibrationWrapper(); PyObject* Prepare(); + PyObject* Prepare(PyObject* input_shapes); PyObject* FeedTensor(PyObject* input_value); diff --git a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc index dcecd880a5e..3d75eccd505 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc @@ -29,6 +29,10 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { .def(py::init([](py::handle& data) { return ::CalibrationWrapper::CreateWrapperCPPFromBuffer(data.ptr()); })) + .def("Prepare", + [](CalibrationWrapper& self, py::handle& input_shapes) { + return tensorflow::pyo_or_throw(self.Prepare(input_shapes.ptr())); + }) .def("Prepare", [](CalibrationWrapper& self) { return tensorflow::pyo_or_throw(self.Prepare()); @@ -50,10 +54,14 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { return tensorflow::pyo_or_throw(self.QuantizeModel( input_py_type, output_py_type, allow_float)); }) - .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, - int output_py_type, bool allow_float, - const char* operator_output_name) { - return tensorflow::pyo_or_throw(self.QuantizeModel( - input_py_type, output_py_type, allow_float, operator_output_name)); + .def("QuantizeModel", + [](CalibrationWrapper& self, int input_py_type, int output_py_type, + bool allow_float, const char* operator_output_name) { + return tensorflow::pyo_or_throw( + self.QuantizeModel(input_py_type, output_py_type, allow_float, + operator_output_name)); + }) + .def("Calibrate", [](CalibrationWrapper& self) { + return tensorflow::pyo_or_throw(self.Calibrate()); }); } diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index fb3b87fdaa7..8f5fab64ffc 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -54,10 +54,17 @@ class Calibrator(object): if not self._calibrator: raise ValueError("Failed to parse the model.") - def calibrate_and_quantize(self, dataset_gen, input_type, output_type, - allow_float): + def calibrate_and_quantize(self, + dataset_gen, + input_type, + output_type, + allow_float, + resize_input=True): """Calibrates the model with specified generator and then quantizes it. + The input shapes of the calibrator are resized with the calibration data if + `resize_input` is set. + Returns: A quantized model. @@ -66,22 +73,36 @@ class Calibrator(object): input_type: A tf.dtype representing the desired real-value input type. output_type: A tf.dtype representing the desired real-value output type. allow_float: A boolean. False if the resulting model cannot perform float - computation, useful when targeting an integer-only backend. - If False, an error will be thrown if an operation cannot be - quantized, otherwise the model will fallback to float ops. + computation, useful when targeting an integer-only backend. If False, an + error will be thrown if an operation cannot be quantized, otherwise the + model will fallback to float ops. + resize_input: A boolean. True if the shape of the sample data is different + from the input. """ - self._calibrator.Prepare() - for calibration_sample in dataset_gen(): - self._calibrator.FeedTensor(calibration_sample) + initialized = False + for sample in dataset_gen(): + if not initialized: + initialized = True + if resize_input: + self._calibrator.Prepare([list(s.shape) for s in sample]) + else: + self._calibrator.Prepare() + self._calibrator.FeedTensor(sample) return self._calibrator.QuantizeModel( np.dtype(input_type.as_numpy_dtype()).num, np.dtype(output_type.as_numpy_dtype()).num, allow_float) - def calibrate_and_quantize_single(self, dataset_gen, input_type, output_type, - allow_float, op_output_name): + def calibrate_and_quantize_single(self, + dataset_gen, + input_type, + output_type, + allow_float, + op_output_name, + resize_input=True): """Calibrates the model with specified generator and then quantizes it. Only the single op with output op_output_name will be quantized. + The input shapes of the calibrator are resized with the calibration data. Returns: A quantized model. @@ -95,10 +116,18 @@ class Calibrator(object): error will be thrown if an operation cannot be quantized, otherwise the model will fallback to float ops. op_output_name: A string, only this op will be quantized. + resize_input: A boolean. True if the shape of the sample data is different + from the input. """ - self._calibrator.Prepare() - for calibration_sample in dataset_gen(): - self._calibrator.FeedTensor(calibration_sample) + initialized = False + for sample in dataset_gen(): + if not initialized: + initialized = True + if resize_input: + self._calibrator.Prepare([list(s.shape) for s in sample]) + else: + self._calibrator.Prepare() + self._calibrator.FeedTensor(sample) return self._calibrator.QuantizeModel( np.dtype(input_type.as_numpy_dtype()).num, np.dtype(output_type.as_numpy_dtype()).num, allow_float, op_output_name) @@ -112,7 +141,10 @@ class Calibrator(object): Args: dataset_gen: A generator that generates calibration samples. """ - self._calibrator.Prepare() - for calibration_sample in dataset_gen(): - self._calibrator.FeedTensor(calibration_sample) - return self._calibrator.calibrate() + initialized = False + for sample in dataset_gen(): + if not initialized: + initialized = True + self._calibrator.Prepare([list(s.shape) for s in sample]) + self._calibrator.FeedTensor(sample) + return self._calibrator.Calibrate() diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index 34e93543f82..ff7e7009c7b 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -130,7 +130,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'Size mismatch'): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False) + constants.FLOAT, False, False) def test_invalid_type_calibrator_gen(self): model_path = resource_loader.get_path_to_datafile( @@ -138,15 +138,28 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): float_model = open(model_path, 'rb').read() quantizer = _calibrator.Calibrator(float_model) - # Input generator with incorrect shape. + # Input generator with incorrect type. def input_gen(): for _ in range(10): - yield np.ones(shape=(1, 5, 5, 3), dtype=np.int32) + yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)] with self.assertRaises(ValueError): quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, constants.FLOAT, False) + def test_calibration(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + # Input generator for the model. + def input_gen(): + for _ in range(10): + yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] + + quantized_model = quantizer.calibrate(input_gen) + self.assertIsNotNone(quantized_model) if __name__ == '__main__': test.main() diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index bb9407b7bfd..803b8be174f 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ #include -#include #include +#include #include "tensorflow/lite/c/common.h" @@ -112,7 +112,7 @@ class SimpleMemoryArena { std::unique_ptr underlying_buffer_; size_t underlying_buffer_size_; char* underlying_buffer_aligned_ptr_; - std::list ordered_allocs_; + std::vector ordered_allocs_; }; } // namespace tflite diff --git a/tensorflow/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl index b0ece0e2d25..eefbe1fb778 100644 --- a/tensorflow/lite/special_rules.bzl +++ b/tensorflow/lite/special_rules.bzl @@ -1,5 +1,10 @@ """External versions of build rules that differ outside of Google.""" +load( + "//tensorflow:tensorflow.bzl", + "clean_dep", +) + def tflite_portable_test_suite(**kwargs): """This is a no-op outside of Google.""" _ignore = [kwargs] @@ -26,3 +31,17 @@ def tflite_extra_gles_deps(): def tflite_ios_lab_runner(version): """This is a no-op outside of Google.""" return None + +def if_nnapi(supported, not_supported = [], supported_android = None): + if supported_android == None: + supported_android = supported + + # We use a blacklist rather than a whitelist for known unsupported platforms. + return select({ + clean_dep("//tensorflow:emscripten"): not_supported, + clean_dep("//tensorflow:ios"): not_supported, + clean_dep("//tensorflow:macos"): not_supported, + clean_dep("//tensorflow:windows"): not_supported, + clean_dep("//tensorflow:android"): supported_android, + "//conditions:default": supported, + }) diff --git a/tensorflow/lite/tflite_with_xnnpack.cc b/tensorflow/lite/tflite_with_xnnpack.cc new file mode 100644 index 00000000000..c8c2c2e02c1 --- /dev/null +++ b/tensorflow/lite/tflite_with_xnnpack.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace tflite { +// Corresponding weak declaration found in lite/model.cc. +std::unique_ptr +AcquireXNNPACKDelegate(int num_threads) { + auto opts = TfLiteXNNPackDelegateOptionsDefault(); + // Note that we don't want to use the thread pool for num_threads == 1. + opts.num_threads = num_threads > 1 ? num_threads : 0; + return std::unique_ptr( + TfLiteXNNPackDelegateCreate(&opts), TfLiteXNNPackDelegateDelete); +} +} // namespace tflite diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 9eacc19ed28..41ccb3df36e 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -16,10 +16,7 @@ py_binary( srcs = ["visualize.py"], python_version = "PY3", srcs_version = "PY2AND3", - deps = [ - "//tensorflow/lite/python:schema_py", - "//tensorflow/python:platform", - ], + deps = ["//tensorflow/lite/python:schema_py"], ) py_test( @@ -31,7 +28,7 @@ py_test( "no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS. ], deps = [ - ":test_utilities", + ":test_utils", ":visualize", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", @@ -44,11 +41,12 @@ py_binary( python_version = "PY3", srcs_version = "PY2AND3", deps = [ - "//tensorflow:tensorflow_py_no_contrib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:platform", - "//tensorflow/python/keras", - "//third_party/py/numpy", - "@six_archive//:six", + "//tensorflow/python:session", ], ) @@ -61,8 +59,10 @@ py_test( deps = [ ":convert_image_to_csv", "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", + "//third_party/py/numpy", ], ) @@ -72,32 +72,51 @@ py_binary( python_version = "PY3", srcs_version = "PY2AND3", deps = [ - "//tensorflow/lite/python:schema_py", + ":flatbuffer_utils", "//tensorflow/python:platform", + ], +) + +py_binary( + name = "randomize_weights", + srcs = ["randomize_weights.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":flatbuffer_utils", + "//tensorflow/python:platform", + ], +) + +py_library( + name = "flatbuffer_utils", + srcs = ["flatbuffer_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/lite/python:schema_py", "@flatbuffers//:runtime_py", ], ) py_test( - name = "strip_strings_test", - srcs = ["strip_strings_test.py"], + name = "flatbuffer_utils_test", + srcs = ["flatbuffer_utils_test.py"], python_version = "PY3", srcs_version = "PY2AND3", tags = [ "no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS. ], deps = [ - ":strip_strings", - ":test_utilities", - "//tensorflow/lite/python:schema_py", + ":flatbuffer_utils", + ":test_utils", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], ) py_library( - name = "test_utilities", - srcs = ["test_utilities.py"], + name = "test_utils", + srcs = ["test_utils.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/lite/python:schema_py", @@ -109,9 +128,9 @@ tf_cc_binary( name = "generate_op_registrations", srcs = ["gen_op_registration_main.cc"], deps = [ + ":gen_op_registration", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/lite/tools:gen_op_registration", "@com_google_absl//absl/strings", ], ) @@ -211,7 +230,10 @@ py_binary( srcs = ["zip_files.py"], python_version = "PY3", visibility = ["//visibility:public"], - deps = ["@absl_py//absl:app"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], ) tflite_portable_test_suite() diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/lite/tools/accuracy/ilsvrc/README.md index f6c96591c1e..d702c2669b6 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/README.md +++ b/tensorflow/lite/tools/accuracy/ilsvrc/README.md @@ -99,7 +99,7 @@ python generate_validation_labels.py \ ``` bazel build -c opt \ - --config=android_arm \ + --config=android_arm64 \ //tensorflow/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval ``` diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 6d946b9702c..1dd7e928c20 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -154,7 +154,7 @@ cc_library( "@com_google_absl//absl/strings", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", - "//tensorflow/lite/experimental/ruy/profiler", + "//tensorflow/lite/experimental/ruy/ruy/profiler", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:profiler", diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index 8790d8d4484..01034fe46ce 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -112,7 +112,7 @@ and the following optional parameters: ``` bazel build -c opt \ - --config=android_arm \ + --config=android_arm64 \ tensorflow/lite/tools/benchmark:benchmark_model ``` @@ -140,7 +140,7 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp That step is only needed when using the Hexagon delegate. ``` -bazel build --config=android_arm \ +bazel build --config=android_arm64 \ tensorflow/lite/experimental/delegates/hexagon/hexagon_nn:libhexagon_interface.so adb push bazel-bin/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/libhexagon_interface.so /data/local/tmp adb push libhexagon_nn_skel*.so /data/local/tmp diff --git a/tensorflow/lite/tools/benchmark/android/README.md b/tensorflow/lite/tools/benchmark/android/README.md index 3e66b7f13f1..00092c4a44f 100644 --- a/tensorflow/lite/tools/benchmark/android/README.md +++ b/tensorflow/lite/tools/benchmark/android/README.md @@ -36,8 +36,9 @@ bazel build -c opt \ ``` adb install -r -d -g bazel-bin/tensorflow/lite/tools/benchmark/android/benchmark_model.apk ``` + Note: Make sure to install with "-g" option to grant the permission for reading -extenal storage. +external storage. (3) Push the compute graph that you need to test. @@ -113,12 +114,12 @@ the system dismisses the notification and displays a third notification "Trace saved", confirming that your trace has been saved and that you're ready to share the system trace. -(9) [Share](https://developer.android.com/topic/performance/tracing/on-device#share-trace) +(9) +[Share](https://developer.android.com/topic/performance/tracing/on-device#share-trace) a trace file, [convert](https://developer.android.com/topic/performance/tracing/on-device#converting_between_trace_formats) between tracing formats and [create](https://developer.android.com/topic/performance/tracing/on-device#create-html-report) -an HTML report. -Note that, the catured tracing file format is either in Perfetto format or in -Systrace format depending on the Android version of your device. Select the -appropriate method to handle the generated file. +an HTML report. Note that, the captured tracing file format is either in +Perfetto format or in Systrace format depending on the Android version of your +device. Select the appropriate method to handle the generated file. diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 47ec9f4af0b..617976991e1 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/strings/numbers.h" -#include "tensorflow/lite/experimental/ruy/profiler/profiler.h" +#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc index e8f48d5b407..cfafc1e9214 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc @@ -73,6 +73,7 @@ TfLiteStatus InferenceProfilerStage::Init() { // TfliteInferenceParams. test_stage_.reset(new TfliteInferenceStage(config_)); if (test_stage_->Init() != kTfLiteOk) return kTfLiteError; + LOG(INFO) << "Test interpreter has been initialized."; // Initialize a reference TfliteInferenceStage that uses the given model & // num_runs, but maintains the rest of TfliteInferenceParams to default. @@ -86,6 +87,7 @@ TfLiteStatus InferenceProfilerStage::Init() { config_.specification().tflite_inference_params().invocations_per_run()); reference_stage_.reset(new TfliteInferenceStage(reference_config)); if (reference_stage_->Init() != kTfLiteOk) return kTfLiteError; + LOG(INFO) << "Reference interpreter (1 thread on CPU) has been initialized."; model_info_ = reference_stage_->GetModelInfo(); diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index 222e44c7168..cbf41de0e03 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -100,6 +100,8 @@ TfLiteStatus TfliteInferenceStage::Init() { auto delegate = CreateTfLiteDelegate(params, &error_message); if (delegate) { delegates_.push_back(std::move(delegate)); + LOG(INFO) << "Successfully created " + << params.Delegate_Name(params.delegate()) << " delegate."; } else { LOG(WARNING) << error_message; } diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py new file mode 100644 index 00000000000..5b513bbfef2 --- /dev/null +++ b/tensorflow/lite/tools/flatbuffer_utils.py @@ -0,0 +1,122 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for FlatBuffers. + +All functions that are commonly used to work with FlatBuffers. + +Refer to the tensorflow lite flatbuffer schema here: +tensorflow/lite/schema/schema.fbs + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random + +from flatbuffers.python import flatbuffers +from tensorflow.lite.python import schema_py_generated as schema_fb + + +def read_model(input_tflite_file): + """Reads and parses a tflite model. + + Args: + input_tflite_file: Full path name to the input tflite file + + Raises: + RuntimeError: If input_tflite_file is not found. + IOError: If input_tflite_file cannot be opened. + + Returns: + A python flatbuffer object corresponding to the input tflite file. + """ + if not os.path.exists(input_tflite_file): + raise RuntimeError('Input file not found at %r\n' % input_tflite_file) + with open(input_tflite_file, 'rb') as file_handle: + file_data = bytearray(file_handle.read()) + model_obj = schema_fb.Model.GetRootAsModel(file_data, 0) + return schema_fb.ModelT.InitFromObj(model_obj) + + +def write_model(model, output_tflite_file): + """Writes the model, a python flatbuffer object, into the output tflite file. + + Args: + model: tflite model + output_tflite_file: Full path name to the output tflite file. + + Raises: + IOError: If output_tflite_file cannot be opened. + """ + # Initial size of the buffer, which will grow automatically if needed + builder = flatbuffers.Builder(1024) + model_offset = model.Pack(builder) + builder.Finish(model_offset) + model_data = builder.Output() + with open(output_tflite_file, 'wb') as out_file: + out_file.write(model_data) + + +def strip_strings(model): + """Strips all nonessential strings from the model to reduce model size. + + We remove the following strings: + (find strings by searching ":string" in the tensorflow lite flatbuffer schema) + 1. Model description + 2. SubGraph name + 3. Tensor names + We retain OperatorCode custom_code and Metadata name. + + Args: + model: The model from which to remove nonessential strings. + + """ + + model.description = '' + for subgraph in model.subgraphs: + subgraph.name = '' + for tensor in subgraph.tensors: + tensor.name = '' + + +def randomize_weights(model, random_seed=0): + """Randomize weights in a model. + + Args: + model: The model in which to randomize weights. + random_seed: The input to the random number generator (default value is 0). + + """ + + # The input to the random seed generator. The default value is 0. + random.seed(random_seed) + + # Parse model buffers which store the model weights + buffers = model.buffers + for i in range(1, len(buffers)): # ignore index 0 as it's always None + buffer_i_data = buffers[i].data + buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size + + # Raw data buffers are of type ubyte (or uint8) whose values lie in the + # range [0, 255]. Those ubytes (or unint8s) are the underlying + # representation of each datatype. For example, a bias tensor of type + # int32 appears as a buffer 4 times it's length of type ubyte (or uint8). + # TODO(b/152324470): This does not work for float as randomized weights may + # end up as denormalized or NaN/Inf floating point numbers. + for j in range(buffer_i_size): + buffer_i_data[j] = random.randint(0, 255) diff --git a/tensorflow/lite/tools/flatbuffer_utils_test.py b/tensorflow/lite/tools/flatbuffer_utils_test.py new file mode 100644 index 00000000000..d2e4fe6daea --- /dev/null +++ b/tensorflow/lite/tools/flatbuffer_utils_test.py @@ -0,0 +1,163 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for flatbuffer_utils.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os + +from tensorflow.lite.tools import flatbuffer_utils +from tensorflow.lite.tools import test_utils +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class WriteReadModelTest(test_util.TensorFlowTestCase): + + def testWriteReadModel(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model_python_object() + # Define temporary files + tmp_dir = self.get_temp_dir() + model_filename = os.path.join(tmp_dir, 'model.tflite') + + # 2. INVOKE + # Invoke the write_model and read_model functions + flatbuffer_utils.write_model(initial_model, model_filename) + final_model = flatbuffer_utils.read_model(model_filename) + + # 3. VALIDATE + # Validate that the initial and final models are the same + # Validate the description + self.assertEqual(initial_model.description, final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertEqual(initial_subgraph.name, final_subgraph.name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertEqual(initial_tensors[i].name, final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for i in range(initial_buffer.size): + self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) + + +class StripStringsTest(test_util.TensorFlowTestCase): + + def testStripStrings(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model_python_object() + final_model = copy.deepcopy(initial_model) + + # 2. INVOKE + # Invoke the strip_strings function + flatbuffer_utils.strip_strings(final_model) + + # 3. VALIDATE + # Validate that the initial and final models are the same except strings + # Validate the description + self.assertNotEqual('', initial_model.description) + self.assertEqual('', final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertNotEqual('', initial_model.subgraphs[0].name) + self.assertEqual('', final_model.subgraphs[0].name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertNotEqual('', initial_tensors[i].name) + self.assertEqual('', final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for i in range(initial_buffer.size): + self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) + + +class RandomizeWeightsTest(test_util.TensorFlowTestCase): + + def testRandomizeWeights(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model_python_object() + final_model = copy.deepcopy(initial_model) + + # 2. INVOKE + # Invoke the randomize_weights function + flatbuffer_utils.randomize_weights(final_model) + + # 3. VALIDATE + # Validate that the initial and final models are the same, except that + # the weights in the model buffer have been modified (i.e, randomized) + # Validate the description + self.assertEqual(initial_model.description, final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertEqual(initial_subgraph.name, final_subgraph.name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertEqual(initial_tensors[i].name, final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for j in range(initial_buffer.size): + self.assertNotEqual(initial_buffer.data[j], final_buffer.data[j]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index ef265ccf719..9043d494235 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -119,7 +119,7 @@ $(wildcard tensorflow/lite/c/*.c) \ $(wildcard tensorflow/lite/core/*.cc) \ $(wildcard tensorflow/lite/core/api/*.cc) \ $(wildcard tensorflow/lite/experimental/resource/*.cc) \ -$(wildcard tensorflow/lite/experimental/ruy/*.cc) +$(wildcard tensorflow/lite/experimental/ruy/ruy/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/lite/kernels/*.cc) \ @@ -146,10 +146,13 @@ $(wildcard tensorflow/lite/*/*/example*.cc) \ $(wildcard tensorflow/lite/*/*/test*.cc) \ $(wildcard tensorflow/lite/*/*/*test.cc) \ $(wildcard tensorflow/lite/*/*/*tool.cc) \ +$(wildcard tensorflow/lite/*/*/*/benchmark.cc) \ +$(wildcard tensorflow/lite/*/*/*/example*.cc) \ +$(wildcard tensorflow/lite/*/*/*/test*.cc) \ $(wildcard tensorflow/lite/*/*/*/*test.cc) \ +$(wildcard tensorflow/lite/*/*/*/*tool.cc) \ $(wildcard tensorflow/lite/kernels/*test_main.cc) \ $(wildcard tensorflow/lite/kernels/*test_util*.cc) \ -tensorflow/lite/experimental/ruy/tune_tool.cc \ tensorflow/lite/tflite_with_xnnpack.cc \ $(MINIMAL_SRCS) diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index c3318f1ab26..3011c01cdeb 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -108,8 +108,8 @@ cc_library( "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels/internal:cppmath", "//tensorflow/lite/kernels/internal:quantization_util", - "//tensorflow/lite/kernels/internal:round", "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 539711cf3b5..33bc4f44596 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -23,8 +23,8 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/minimal_logging.h" diff --git a/tensorflow/lite/tools/pip_package/Dockerfile b/tensorflow/lite/tools/pip_package/Dockerfile index e4d47d0fa0d..3b92596758b 100644 --- a/tensorflow/lite/tools/pip_package/Dockerfile +++ b/tensorflow/lite/tools/pip_package/Dockerfile @@ -14,6 +14,7 @@ RUN apt-get update && \ python-setuptools \ python-wheel \ python-numpy \ + python-pip \ libpython-dev \ libpython-dev:armhf \ libpython-dev:arm64 \ @@ -21,6 +22,7 @@ RUN apt-get update && \ python3-setuptools \ python3-wheel \ python3-numpy \ + python3-pip \ libpython3-dev \ libpython3-dev:armhf \ libpython3-dev:arm64 \ @@ -29,8 +31,11 @@ RUN apt-get update && \ zlib1g-dev \ zlib1g-dev:armhf \ zlib1g-dev:arm64 \ - swig \ curl \ unzip \ git && \ apt-get clean +RUN pip install pip --upgrade +RUN pip install pybind11 +RUN pip3 install pip --upgrade +RUN pip3 install pybind11 diff --git a/tensorflow/lite/tools/pip_package/setup.py b/tensorflow/lite/tools/pip_package/setup.py index f99a5b043dc..2f2515145c4 100644 --- a/tensorflow/lite/tools/pip_package/setup.py +++ b/tensorflow/lite/tools/pip_package/setup.py @@ -24,9 +24,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import glob import multiprocessing import os import subprocess +import sys +import sysconfig from distutils.command.build_ext import build_ext import numpy @@ -65,7 +68,7 @@ for name in ['TARGET', 'TARGET_ARCH', 'CC_PREFIX', 'EXTRA_CXXFLAGS']: # with more than 4GB, use all the CPUs, otherwise only 1. def get_build_cpus(): physical_bytes = os.sysconf('SC_PAGESIZE') * os.sysconf('SC_PHYS_PAGES') - if physical_bytes < (1<<30) * 4: + if physical_bytes < (1 << 30) * 4: return 1 else: return multiprocessing.cpu_count() @@ -73,9 +76,9 @@ def get_build_cpus(): def make_args(target='', quiet=True): """Construct make command line.""" - args = (['make', 'SHELL=/bin/bash', - 'BUILD_WITH_NNAPI=false', '-C', TENSORFLOW_DIR] - + MAKE_CROSS_OPTIONS + + args = ([ + 'make', 'SHELL=/bin/bash', 'BUILD_WITH_NNAPI=false', '-C', TENSORFLOW_DIR + ] + MAKE_CROSS_OPTIONS + ['-f', RELATIVE_MAKEFILE_PATH, '-j', str(get_build_cpus())]) if quiet: @@ -128,28 +131,55 @@ class CustomBuildPy(build_py, object): return super(CustomBuildPy, self).run() +def get_pybind_include(): + """pybind11 include directory is not correctly resolved. + + This fixes include directory to /usr/local/pythonX.X + + Returns: + include directories to find pybind11 + """ + if sys.version_info[0] == 3: + include_dirs = glob.glob('/usr/local/include/python3*') + else: + include_dirs = glob.glob('/usr/local/include/python2*') + include_dirs.append(sysconfig.get_path('include')) + tmp_include_dirs = [] + pip_dir = os.path.join(TENSORFLOW_DIR, 'tensorflow', 'lite', 'tools', + 'pip_package', 'gen') + for include_dir in include_dirs: + tmp_include_dir = os.path.join(pip_dir, include_dir[1:]) + tmp_include_dirs.append(tmp_include_dir) + try: + os.makedirs(tmp_include_dir) + os.symlink(include_dir, os.path.join(tmp_include_dir, 'include')) + except IOError: # file already exists. + pass + return tmp_include_dirs + + LIB_TFLITE = 'tensorflow-lite' LIB_TFLITE_DIR = make_output('libdir') ext = Extension( - name='%s._interpreter_wrapper' % PACKAGE_NAME, + name='%s._pywrap_tensorflow_interpreter_wrapper' % PACKAGE_NAME, language='c++', - sources=['interpreter_wrapper/interpreter_wrapper.i', - 'interpreter_wrapper/interpreter_wrapper.cc', - 'interpreter_wrapper/numpy.cc', - 'interpreter_wrapper/python_error_reporter.cc', - 'interpreter_wrapper/python_utils.cc'], + sources=[ + 'interpreter_wrapper/interpreter_wrapper.cc', + 'interpreter_wrapper/interpreter_wrapper_pybind11.cc', + 'interpreter_wrapper/numpy.cc', + 'interpreter_wrapper/python_error_reporter.cc', + 'interpreter_wrapper/python_utils.cc' + ], extra_compile_args=['--std=c++11'], - swig_opts=['-c++', - '-I%s' % TENSORFLOW_DIR, - '-module', 'interpreter_wrapper', - '-outdir', PACKAGE_NAME], - include_dirs=[TENSORFLOW_DIR, - os.path.join(TENSORFLOW_DIR, 'tensorflow', 'lite', 'tools', - 'pip_package'), - numpy.get_include(), - os.path.join(DOWNLOADS_DIR, 'flatbuffers', 'include'), - os.path.join(DOWNLOADS_DIR, 'absl')], + include_dirs=[ + TENSORFLOW_DIR, + os.path.join(TENSORFLOW_DIR, 'tensorflow', 'lite', 'tools', + 'pip_package'), + numpy.get_include(), + os.path.join(DOWNLOADS_DIR, 'flatbuffers', 'include'), + os.path.join(DOWNLOADS_DIR, 'absl') + ] + get_pybind_include(), libraries=[LIB_TFLITE], library_dirs=[LIB_TFLITE_DIR]) @@ -186,9 +216,9 @@ setup( ext_modules=[ext], install_requires=[ 'numpy >= 1.16.0', + 'pybind11 >= 2.4.3', ], cmdclass={ 'build_ext': CustomBuildExt, 'build_py': CustomBuildPy, - } -) + }) diff --git a/tensorflow/lite/tools/randomize_weights.py b/tensorflow/lite/tools/randomize_weights.py new file mode 100644 index 00000000000..84bbe3955a7 --- /dev/null +++ b/tensorflow/lite/tools/randomize_weights.py @@ -0,0 +1,62 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Randomize all weights in a tflite file. + +Example usage: +python randomize_weights.py foo.tflite foo_randomized.tflite +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.lite.tools import flatbuffer_utils +from tensorflow.python.platform import app + + +def main(_): + parser = argparse.ArgumentParser( + description='Randomize weights in a tflite file.') + parser.add_argument( + '--input_tflite_file', + type=str, + required=True, + help='Full path name to the input tflite file.') + parser.add_argument( + '--output_tflite_file', + type=str, + required=True, + help='Full path name to the output randomized tflite file.') + parser.add_argument( + '--random_seed', + type=str, + required=False, + default=0, + help='Input to the random number generator. The default value is 0.') + args = parser.parse_args() + + # Read the model + model = flatbuffer_utils.read_model(args.input_tflite_file) + # Invoke the randomize weights function + flatbuffer_utils.randomize_weights(model, args.random_seed) + # Write the model + flatbuffer_utils.write_model(model, args.output_tflite_file) + + +if __name__ == '__main__': + app.run(main=main, argv=sys.argv[:1]) diff --git a/tensorflow/lite/tools/strip_strings.py b/tensorflow/lite/tools/strip_strings.py index be9e726835a..cc88562caf1 100644 --- a/tensorflow/lite/tools/strip_strings.py +++ b/tensorflow/lite/tools/strip_strings.py @@ -12,17 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This tool strips all nonessential strings from a tflite file. - -Refer to the schema here: //third_party/tensorflow/lite/schema/schema.fbs -We remove the following strings: (search for ":string" in this schema) -1. Tensor names -2. SubGraph name -3. Model description -We retain OperatorCode custom_code and Metadata name. +"""Strips all nonessential strings from a tflite file. Example usage: - python strip_strings.py foo.tflite foo_stripped.tflite """ @@ -31,47 +23,12 @@ from __future__ import division from __future__ import print_function import argparse -import os import sys -from flatbuffers.python import flatbuffers -from tensorflow.lite.python import schema_py_generated as schema_fb +from tensorflow.lite.tools import flatbuffer_utils from tensorflow.python.platform import app -def StripTfliteFile(input_tflite_file, output_tflite_file): - """Strips all nonessential strings from the model to reduce model size. - - Args: - input_tflite_file: Full path name to the input tflite file - output_tflite_file: Full path name to the stripped output tflite file. - - Raises: - RuntimeError: If input_tflite_file is not found. - IOError: If input_tflite_file or output_tflite_file cannot be opened. - - """ - - if not os.path.exists(input_tflite_file): - raise RuntimeError('Input file not found at %r\n' % input_tflite_file) - with open(input_tflite_file, 'rb') as file_handle: - file_data = bytearray(file_handle.read()) - model_obj = schema_fb.Model.GetRootAsModel(file_data, 0) - model = schema_fb.ModelT.InitFromObj(model_obj) - model.description = '' - for subgraph in model.subgraphs: - subgraph.name = '' - for tensor in subgraph.tensors: - tensor.name = '' - builder = flatbuffers.Builder(1024) # Initial size of the buffer, which - # will grow automatically if needed - model_offset = model.Pack(builder) - builder.Finish(model_offset) - model_data = builder.Output() - with open(output_tflite_file, 'wb') as out_file: - out_file.write(model_data) - - def main(_): """Application run loop.""" parser = argparse.ArgumentParser( @@ -88,8 +45,12 @@ def main(_): help='Full path name to the stripped output tflite file.') args = parser.parse_args() + # Read the model + model = flatbuffer_utils.read_model(args.input_tflite_file) # Invoke the strip tflite file function - StripTfliteFile(args.input_tflite_file, args.output_tflite_file) + flatbuffer_utils.strip_strings(model) + # Write the model + flatbuffer_utils.write_model(model, args.output_tflite_file) if __name__ == '__main__': diff --git a/tensorflow/lite/tools/strip_strings_test.py b/tensorflow/lite/tools/strip_strings_test.py deleted file mode 100644 index 9ba2c991444..00000000000 --- a/tensorflow/lite/tools/strip_strings_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for strip_strings.py.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.lite.python import schema_py_generated as schema_fb -from tensorflow.lite.tools import strip_strings -from tensorflow.lite.tools import test_utilities -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test - - -class StripTensorNamesTest(test_util.TensorFlowTestCase): - - def testStripTensorNames(self): - # Define mock model - model_mock = test_utilities.BuildMockModel() - - # Define temporary files - tmp_dir = self.get_temp_dir() - model_filename = os.path.join(tmp_dir, 'model.tflite') - model_stripped_filename = os.path.join(tmp_dir, 'model_stripped.tflite') - - # Validate the mock model - model = schema_fb.Model.GetRootAsModel(model_mock, 0) - model_tensors = model.Subgraphs(0).Tensors - self.assertEqual(b'input_tensor', model_tensors(0).Name()) - self.assertEqual(b'constant_tensor', model_tensors(1).Name()) - self.assertEqual(b'output_tensor', model_tensors(2).Name()) - - # Store the model locally in model_filename - with open(model_filename, 'wb') as model_file: - model_file.write(model_mock) - # Invoke the StripTfliteFile function to remove string names - strip_strings.StripTfliteFile(model_filename, model_stripped_filename) - # Read the locally stored model in model_stripped_filename - with open(model_stripped_filename, 'rb') as model_file: - model_stripped = model_file.read() - - # Validate the model stripped of tensor names - model_stripped = schema_fb.Model.GetRootAsModel(model_stripped, 0) - model_stripped_tensors = model_stripped.Subgraphs(0).Tensors - self.assertEqual(b'', model_stripped_tensors(0).Name()) - self.assertEqual(b'', model_stripped_tensors(1).Name()) - self.assertEqual(b'', model_stripped_tensors(2).Name()) - - def testStripSubGraphNames(self): - # Define mock model - model_mock = test_utilities.BuildMockModel() - - # Define temporary files - tmp_dir = self.get_temp_dir() - model_filename = os.path.join(tmp_dir, 'model.tflite') - model_stripped_filename = os.path.join(tmp_dir, 'model_stripped.tflite') - - # Validate the mock model - model = schema_fb.Model.GetRootAsModel(model_mock, 0) - self.assertEqual(b'subgraph_name', model.Subgraphs(0).Name()) - - # Store the model locally in model_filename - with open(model_filename, 'wb') as model_file: - model_file.write(model_mock) - # Invoke the StripTfliteFile function to remove string names - strip_strings.StripTfliteFile(model_filename, model_stripped_filename) - # Read the locally stored model in model_stripped_filename - with open(model_stripped_filename, 'rb') as model_file: - model_stripped = model_file.read() - - # Validate the model stripped of subgraph names - model_stripped = schema_fb.Model.GetRootAsModel(model_stripped, 0) - self.assertEqual(b'', model_stripped.Subgraphs(0).Name()) - - def testStripModelDescription(self): - # Define mock model - model_mock = test_utilities.BuildMockModel() - - # Define temporary files - tmp_dir = self.get_temp_dir() - model_filename = os.path.join(tmp_dir, 'model.tflite') - model_stripped_filename = os.path.join(tmp_dir, 'model_stripped.tflite') - - # Validate the mock model - model = schema_fb.Model.GetRootAsModel(model_mock, 0) - self.assertEqual(b'model_description', model.Description()) - - # Store the model locally in model_filename - with open(model_filename, 'wb') as model_file: - model_file.write(model_mock) - # Invoke the StripTfliteFile function to remove string names - strip_strings.StripTfliteFile(model_filename, model_stripped_filename) - # Read the locally stored model in model_stripped_filename - with open(model_stripped_filename, 'rb') as model_file: - model_stripped = model_file.read() - - # Validate the model stripped of model description - model_stripped = schema_fb.Model.GetRootAsModel(model_stripped, 0) - self.assertEqual(b'', model_stripped.Description()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/lite/tools/test_utilities.py b/tensorflow/lite/tools/test_utils.py similarity index 94% rename from tensorflow/lite/tools/test_utilities.py rename to tensorflow/lite/tools/test_utils.py index a6a994898d7..75a649b729a 100644 --- a/tensorflow/lite/tools/test_utilities.py +++ b/tensorflow/lite/tools/test_utils.py @@ -25,24 +25,24 @@ from flatbuffers.python import flatbuffers from tensorflow.lite.python import schema_py_generated as schema_fb -def BuildMockModel(): - """Creates a flatbuffer object containing an example model.""" +def build_mock_model(): + """Creates a flatbuffer containing an example model.""" builder = flatbuffers.Builder(1024) schema_fb.BufferStart(builder) buffer0_offset = schema_fb.BufferEnd(builder) schema_fb.BufferStartDataVector(builder, 10) - builder.PrependUint8(0) - builder.PrependUint8(1) - builder.PrependUint8(2) - builder.PrependUint8(3) - builder.PrependUint8(4) - builder.PrependUint8(5) - builder.PrependUint8(6) - builder.PrependUint8(7) - builder.PrependUint8(8) builder.PrependUint8(9) + builder.PrependUint8(8) + builder.PrependUint8(7) + builder.PrependUint8(6) + builder.PrependUint8(5) + builder.PrependUint8(4) + builder.PrependUint8(3) + builder.PrependUint8(2) + builder.PrependUint8(1) + builder.PrependUint8(0) buffer1_data_offset = builder.EndVector(10) schema_fb.BufferStart(builder) schema_fb.BufferAddData(builder, buffer1_data_offset) @@ -200,6 +200,15 @@ def BuildMockModel(): schema_fb.ModelAddBuffers(builder, buffers_offset) model_offset = schema_fb.ModelEnd(builder) builder.Finish(model_offset) - model_data = builder.Output() + model = builder.Output() - return model_data + return model + + +def build_mock_model_python_object(): + """Creates a python flatbuffer object containing an example model.""" + model_mock = build_mock_model() + model_obj = schema_fb.Model.GetRootAsModel(model_mock, 0) + model = schema_fb.ModelT.InitFromObj(model_obj) + + return model diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 8598c9c1bb2..031b2e17583 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -216,10 +216,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_SPLIT: // If the op take int8 input, it is version 2, for int32 it's version 3. - if (op_sig.input_types.at(0) == TensorType_INT32) { + // The input tensor is at index 1 not 0, 0 is the axis. + if (op_sig.input_types.at(1) == TensorType_INT32) { return 3; } - if (op_sig.input_types.at(0) == TensorType_INT8) { + if (op_sig.input_types.at(1) == TensorType_INT8) { return 2; } return 1; diff --git a/tensorflow/lite/tools/visualize_test.py b/tensorflow/lite/tools/visualize_test.py index 7368f34a37b..8beb8f801da 100644 --- a/tensorflow/lite/tools/visualize_test.py +++ b/tensorflow/lite/tools/visualize_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import os import re -from tensorflow.lite.tools import test_utilities +from tensorflow.lite.tools import test_utils from tensorflow.lite.tools import visualize from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -35,7 +35,7 @@ class VisualizeTest(test_util.TensorFlowTestCase): self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10)) def testFlatbufferToDict(self): - model_data = test_utilities.BuildMockModel() + model_data = test_utils.build_mock_model() model_dict = visualize.CreateDictFromFlatbuffer(model_data) self.assertEqual(0, model_dict['version']) self.assertEqual(1, len(model_dict['subgraphs'])) @@ -45,7 +45,7 @@ class VisualizeTest(test_util.TensorFlowTestCase): self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer']) def testVisualize(self): - model_data = test_utilities.BuildMockModel() + model_data = test_utils.build_mock_model() tmp_dir = self.get_temp_dir() model_filename = os.path.join(tmp_dir, 'model.tflite') diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2f65cb57eca..c402dcf947d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5,12 +5,26 @@ load("//tensorflow:tensorflow.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_tests") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_plugin_deps", "tf_additional_xla_deps_py") @@ -1864,6 +1878,7 @@ py_library( ":tensor_shape", ":util", "//tensorflow/core:protos_all_py", + "//tensorflow/python/types", ], ) @@ -3178,6 +3193,7 @@ tf_py_test( size = "small", srcs = ["ops/collective_ops_test.py"], python_version = "PY3", + tags = ["no_rocm"], deps = [ ":client_testlib", ":collective_ops", @@ -5204,7 +5220,7 @@ py_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", - "//tensorflow/python/keras/optimizer_v2:learning_rate_schedule", + "//tensorflow/python/keras/optimizer_v2:legacy_learning_rate_decay", "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", @@ -5843,6 +5859,7 @@ tf_py_wrap_cc( "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core/data/service:server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", @@ -5898,6 +5915,7 @@ filegroup( "//tensorflow/compiler/mlir/python:mlir", # mlir "//tensorflow/core:core_cpu_base_no_ops", # tf_session "//tensorflow/core:core_cpu_impl", # device_lib + "//tensorflow/core/data/service:server_lib", # server_lib "//tensorflow/core:framework_internal_impl", # op_def_registry "//tensorflow/core:lib_internal_impl", # device_lib "//tensorflow/core:op_gen_lib", # tf_session @@ -6409,6 +6427,9 @@ cuda_py_test( size = "small", srcs = ["client/timeline_test.py"], python_version = "PY3", + tags = [ + "gpu_cupti", + ], xla_enable_strict_auto_jit = False, # Graph structure is different with autojit deps = [ ":client", @@ -6472,6 +6493,7 @@ tf_py_test( size = "small", srcs = ["framework/convert_to_constants_test.py"], python_version = "PY3", + tags = ["no_rocm"], deps = [ "client_testlib", "framework_test_lib", @@ -6496,7 +6518,10 @@ tf_py_test( size = "small", srcs = ["lib/io/file_io_test.py"], python_version = "PY3", - tags = ["no_windows"], + tags = [ + "no_rocm", + "no_windows", + ], deps = [ ":client_testlib", ":errors", @@ -6591,7 +6616,6 @@ cuda_py_tests( "training/device_setter_test.py", "training/ftrl_test.py", "training/gradient_descent_test.py", - "training/learning_rate_decay_test.py", "training/momentum_test.py", "training/optimizer_test.py", "training/proximal_adagrad_test.py", @@ -7598,35 +7622,6 @@ tf_py_test( ], ) -py_library( - name = "graph_placer", - srcs = [ - "grappler/controller.py", - "grappler/graph_placer.py", - "grappler/hierarchical_controller.py", - ], - deps = [ - ":python", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "graph_placer_test", - size = "large", - srcs = ["grappler/graph_placer_test.py"], - python_version = "PY3", - tags = [ - "grappler", - "no_pip", # graph_placer is not available in pip. - ], - deps = [ - ":client_testlib", - ":graph_placer", - ":math_ops", - ], -) - tf_py_test( name = "memory_optimizer_test", size = "medium", @@ -7825,10 +7820,7 @@ cuda_py_test( "grappler/auto_mixed_precision_test.py", ], python_version = "PY3", - tags = [ - "grappler", - "no_rocm", - ], + tags = ["grappler"], # This test analyzes the graph, but XLA changes the names of nodes. xla_enable_strict_auto_jit = False, deps = [ diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index d81723cf04c..655dc118a37 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -24,7 +24,6 @@ py_library( "config_lib.py", "converter.py", "function_wrappers.py", - "naming.py", "unsupported_features_checker.py", ], srcs_version = "PY2AND3", @@ -79,14 +78,3 @@ py_test( "//tensorflow/python:client_testlib", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":core", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 4ea1187f8ed..4b170159b8b 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -30,9 +30,9 @@ from tensorflow.python.autograph import utils from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import loader +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer diff --git a/tensorflow/python/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py deleted file mode 100644 index 67a565a9270..00000000000 --- a/tensorflow/python/autograph/core/naming.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Symbol naming utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import enum - -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.utils import misc - - -class _NamingStyle(enum.Enum): - SNAKE = 1 - CAMEL = 2 - - -class Namer(object): - """Symbol name generator.""" - - def __init__(self, global_namespace): - self.global_namespace = global_namespace - self.generated_names = set() - - def _as_symbol_name(self, fqn, style=_NamingStyle.SNAKE): - """Returns a symbol name that matches a fully-qualified name. - - The returned name is safe to use for Python symbols. Any special characters - present in fqn are replaced according to the style argument. - - Examples: - - self._as_symbol_name('foo.bar', style=_NamingStyle.CAMEL) == 'FooBar' - self._as_symbol_name('foo.bar', style=_NamingStyle.SNAKE) == 'foo_bar' - - See the unit tests for more examples. - - Args: - fqn: Union[Text, Tuple[Text]] a fully-qualified symbol name. The qualifier - may include module, class names, attributes, etc. - style: _NamingStyle - Returns: - Text - """ - assert style in _NamingStyle - - if isinstance(fqn, tuple): - cn = '.'.join(fqn) - else: - cn = fqn - - # Until we clean up the whole FQN mechanism, `fqn` may not be - # canonical, that is, in can appear as ('foo.bar', 'baz') - # This replaces any characters that might remain because of that. - pieces = cn.split('.') - - if style == _NamingStyle.CAMEL: - pieces = tuple(misc.capitalize_initial(p) for p in pieces) - return ''.join(pieces) - elif style == _NamingStyle.SNAKE: - return '_'.join(pieces) - - def class_name(self, original_fqn): - """Returns the name of a converted class.""" - canonical_name = self._as_symbol_name( - original_fqn, style=_NamingStyle.CAMEL) - new_name_root = 'Tf%s' % canonical_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) - self.generated_names.add(new_name) - return new_name - - def function_name(self, original_fqn): - """Returns the name of a converted function.""" - canonical_name = self._as_symbol_name( - original_fqn, style=_NamingStyle.SNAKE) - new_name_root = 'tf__%s' % canonical_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) - self.generated_names.add(new_name) - return new_name - - def new_symbol(self, name_root, reserved_locals): - """See control_flow.SymbolNamer.new_symbol.""" - # reserved_locals may contain QNs. - all_reserved_locals = set() - for s in reserved_locals: - if isinstance(s, qual_names.QN): - all_reserved_locals.update(s.qn) - elif isinstance(s, str): - all_reserved_locals.add(s) - else: - raise ValueError('Unexpected symbol type "%s"' % type(s)) - - pieces = name_root.split('_') - if pieces[-1].isdigit(): - name_root = '_'.join(pieces[:-1]) - n = int(pieces[-1]) - else: - n = 0 - new_name = name_root - - while (new_name in self.global_namespace or - new_name in all_reserved_locals or new_name in self.generated_names): - n += 1 - new_name = '%s_%d' % (name_root, n) - - self.generated_names.add(new_name) - return new_name diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 616a74e4f2a..146d4b6ec2c 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -375,6 +375,44 @@ def _is_known_loaded_type(f, module_name, entity_name): return False +def _fall_back_unconverted(f, args, kwargs, options, exc): + """Falls back to calling the function unconverted, in case of error.""" + # TODO(mdan): Consider adding an internal metric. + warning_template = ( + 'AutoGraph could not transform %s and will run it as-is.\n' + '%s' + 'Cause: %s\n' + 'To silence this warning, decorate the function with' + ' @tf.autograph.experimental.do_not_convert') + if isinstance(exc, errors.UnsupportedLanguageElementError): + if not conversion.is_in_whitelist_cache(f, options): + logging.warn(warning_template, f, '', exc) + else: + file_bug_message = ( + 'Please report this to the TensorFlow team. When filing the bug, set' + ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' + ' attach the full output.\n') + logging.warn(warning_template, f, file_bug_message, exc) + + return _call_unconverted(f, args, kwargs, options) + + +def _log_callargs(f, args, kwargs): + """Logging helper.""" + logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) + if not six.PY2: + logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) + + if kwargs is not None: + callargs = tf_inspect.getcallargs(f, *args, **kwargs) + else: + callargs = tf_inspect.getcallargs(f, *args) + + formatted_callargs = '\n'.join( + ' {}: {}'.format(k, v) for k, v in callargs.items()) + logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) + + def converted_call(f, args, kwargs, @@ -498,9 +536,7 @@ def converted_call(f, if not options.internal_convert_user_code: return _call_unconverted(f, args, kwargs, options) - # TODO(mdan): Move this entire block inside to_graph. - try: # Begin of transformation error guards - + try: if inspect.ismethod(f) or inspect.isfunction(f): target_entity = f effective_args = args @@ -514,6 +550,8 @@ def converted_call(f, elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'): # Callable objects. Dunder methods have special lookup rules, see: # https://docs.python.org/3/reference/datamodel.html#specialnames + # TODO(mdan): Recurse into converted_call to simplify other verifications. + # This should be handled in the same way as partials. target_entity = f.__class__.__call__ effective_args = (f,) + args @@ -521,63 +559,34 @@ def converted_call(f, target_entity = f raise NotImplementedError('unknown callable type "%s"' % type(f)) - if not inspect.isclass(target_entity): - if not hasattr(target_entity, '__code__'): - logging.log(2, 'Permanently whitelisted: %s: native binding', - target_entity) - return _call_unconverted(f, args, kwargs, options) - elif (hasattr(target_entity.__code__, 'co_filename') and - target_entity.__code__.co_filename == ''): - # TODO(mdan): __globals__['txt'] might work in Py3. - logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)', - target_entity) - return _call_unconverted(f, args, kwargs, options) - - program_ctx = converter.ProgramContext( - options=options, autograph_module=tf_inspect.getmodule(converted_call)) - converted_f = conversion.convert(target_entity, program_ctx) - - if logging.has_verbosity(2): - logging.log(2, 'Defaults of %s : %s', converted_f, - converted_f.__defaults__) - if not six.PY2: - logging.log(2, 'KW defaults of %s : %s', - converted_f, converted_f.__kwdefaults__) - - if kwargs is not None: - callargs = tf_inspect.getcallargs(converted_f, *effective_args, - **kwargs) - else: - callargs = tf_inspect.getcallargs(converted_f, *effective_args) - - formatted_callargs = '\n'.join( - ' {}: {}'.format(k, v) for k, v in callargs.items()) - logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs) - except Exception as e: # pylint:disable=broad-except logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) if is_autograph_strict_conversion_mode(): raise + return _fall_back_unconverted(f, args, kwargs, options, e) - warning_template = ( - 'AutoGraph could not transform %s and will run it as-is.\n' - '%s' - 'Cause: %s\n' - 'To silence this warning, decorate the function with' - ' @tf.autograph.experimental.do_not_convert') - if isinstance(e, errors.UnsupportedLanguageElementError): - # Repeating the check made upon function entry because the state might - # have updated in the meantime. - if not conversion.is_in_whitelist_cache(f, options): - logging.warn(warning_template, target_entity, '', e) - else: - file_bug_message = ( - 'Please report this to the TensorFlow team. When filing the bug, set' - ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' - ' attach the full output.\n') - logging.warn(warning_template, target_entity, file_bug_message, e) - + if not hasattr(target_entity, '__code__'): + logging.log(2, 'Permanently whitelisted: %s: native binding', + target_entity) return _call_unconverted(f, args, kwargs, options) + elif (hasattr(target_entity.__code__, 'co_filename') and + target_entity.__code__.co_filename == ''): + # TODO(mdan): __globals__['txt'] might work in Py3. + logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)', + target_entity) + return _call_unconverted(f, args, kwargs, options) + + try: + program_ctx = converter.ProgramContext( + options=options, autograph_module=tf_inspect.getmodule(converted_call)) + converted_f = conversion.convert(target_entity, program_ctx) + if logging.has_verbosity(2): + _log_callargs(converted_f, effective_args, kwargs) + except Exception as e: # pylint:disable=broad-except + logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) + if is_autograph_strict_conversion_mode(): + raise + return _fall_back_unconverted(f, args, kwargs, options, e) with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter(): try: diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 4365edaaa8e..d8f73f20674 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -753,6 +753,27 @@ class ApiTest(test.TestCase): self.assertAllEqual(1, self.evaluate(x)) + def test_converted_call_native_binding(self): + x = api.converted_call(np.power, (2, 2), None, options=DEFAULT_RECURSIVE) + self.assertAllEqual(x, 4) + + def test_converted_call_native_binding_errorneous(self): + + class FaultyBinding(object): + + def __array__(self): + raise ValueError('fault') + + bad_obj = FaultyBinding() + + def fail_if_warning(*_): + self.fail('No warning should be issued') + + with test.mock.patch.object(ag_logging, 'warn', fail_if_warning): + with self.assertRaisesRegex(ValueError, 'fault'): + api.converted_call( + np.power, (bad_obj, 2), None, options=DEFAULT_RECURSIVE) + def test_converted_call_through_tf_dataset(self): def other_fn(x): diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index e14c8e2bfcf..7134c2c0b69 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -48,12 +48,12 @@ from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import loader +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer @@ -572,7 +572,7 @@ def convert_func_to_ast(f, program_ctx, do_rename=True): if isinstance(node, gast.Lambda): new_name = namer.new_symbol('tf__lambda', ()) elif do_rename: - new_name = namer.function_name(f.__name__) + new_name = namer.new_symbol('tf__' + f.__name__, ()) else: new_name = f.__name__ diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index a1ef5eeedab..40493f07a2d 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -340,8 +340,11 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts): """ if tensor_util.is_tensor(iter_): if tensors.is_range_tensor(iter_): - _tf_range_for_stmt( - iter_, extra_test, body, get_state, set_state, symbol_names, opts) + _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, + symbol_names, opts) + elif isinstance(iter_, ragged_tensor.RaggedTensor): + _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, + symbol_names, opts) else: _known_len_tf_for_stmt( iter_, extra_test, body, get_state, set_state, symbol_names, opts) diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 5311392263c..7881b17f88b 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -31,6 +31,7 @@ py_library( "inspect_utils.py", "loader.py", "loader_deprecated_py2.py", + "naming.py", "origin_info.py", "parser.py", "pretty_printer.py", @@ -133,6 +134,17 @@ sh_test( tags = ["no_oss"], ) +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "origin_info_test", srcs = ["origin_info_test.py"], diff --git a/tensorflow/python/autograph/pyct/naming.py b/tensorflow/python/autograph/pyct/naming.py new file mode 100644 index 00000000000..c7d239bd7e6 --- /dev/null +++ b/tensorflow/python/autograph/pyct/naming.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Symbol naming utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.pyct import qual_names + + +class Namer(object): + """Symbol name generator.""" + + def __init__(self, global_namespace): + self.global_namespace = global_namespace + self.generated_names = set() + + def new_symbol(self, name_root, reserved_locals): + """See control_flow.SymbolNamer.new_symbol.""" + # reserved_locals may contain QNs. + all_reserved_locals = set() + for s in reserved_locals: + if isinstance(s, qual_names.QN): + all_reserved_locals.update(s.qn) + elif isinstance(s, str): + all_reserved_locals.add(s) + else: + raise ValueError('Unexpected symbol type "%s"' % type(s)) + + pieces = name_root.split('_') + if pieces[-1].isdigit(): + name_root = '_'.join(pieces[:-1]) + n = int(pieces[-1]) + else: + n = 0 + new_name = name_root + + while (new_name in self.global_namespace or + new_name in all_reserved_locals or new_name in self.generated_names): + n += 1 + new_name = '%s_%d' % (name_root, n) + + self.generated_names.add(new_name) + return new_name diff --git a/tensorflow/python/autograph/core/naming_test.py b/tensorflow/python/autograph/pyct/naming_test.py similarity index 60% rename from tensorflow/python/autograph/core/naming_test.py rename to tensorflow/python/autograph/pyct/naming_test.py index 49526ed77f3..61fe22068e4 100644 --- a/tensorflow/python/autograph/core/naming_test.py +++ b/tensorflow/python/autograph/pyct/naming_test.py @@ -18,40 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.autograph.core import naming +from tensorflow.python.autograph.pyct import naming from tensorflow.python.platform import test class NamerTest(test.TestCase): - def test_function_name_tracks_names(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo', namer.function_name('foo')) - self.assertEqual('tf__bar', namer.function_name('bar')) - self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names) - - def test_function_name_consistent(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo', namer.function_name('foo')) - self.assertEqual('tf__foo', namer.function_name('foo')) - - def test_function_name_unsanitized_fqn(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo_bar', namer.function_name('foo.bar')) - self.assertEqual('tf__foo_bar_baz', namer.function_name(('foo.bar', 'baz'))) - - def test_class_name_basic(self): - namer = naming.Namer({}) - self.assertEqual('TfFooBar', namer.class_name(('foo', 'Bar'))) - - def test_class_name_unsanitized_fqn(self): - namer = naming.Namer({}) - self.assertEqual('TfFooBarBaz', namer.class_name(('foo.bar', 'Baz'))) - - def test_function_name_avoids_global_conflicts(self): - namer = naming.Namer({'tf__foo': 1}) - self.assertEqual('tf__foo_1', namer.function_name('foo')) - def test_new_symbol_tracks_names(self): namer = naming.Namer({}) self.assertEqual('temp', namer.new_symbol('temp', set())) diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py index 28cd9427bd1..07dcde7fdc7 100644 --- a/tensorflow/python/autograph/pyct/transformer.py +++ b/tensorflow/python/autograph/pyct/transformer.py @@ -181,20 +181,19 @@ class _State(object): return self._value[key] -class Base(gast.NodeTransformer): - """Base class for general-purpose code transformers transformers. +class NodeStateTracker(object): + """Base class for general-purpose Python code transformation. - This is an extension of ast.NodeTransformer that provides a few additional - functions, like state tracking within the scope of arbitrary node, helpers - for processing code blocks, debugging, mapping of transformed code to - original code, and others. + This abstract class provides helpful functions, like state tracking within + the scope of arbitrary node, helpers for processing code blocks, debugging, + mapping of transformed code to original code, and others. Scope-local state tracking: to keep state across nodes, at the level of (possibly nested) scopes, use enter/exit_local_scope and set/get_local. You must call enter/exit_local_scope manually, but the transformer detects when they are not properly paired. - The transformer allows keeping state across calls to `visit_*` that is local + The transformer allows keeping state across calls that is local to arbitrary nodes and their descendants, using the self.state attribute. Multiple independent scopes are allowed and automatically constructed. @@ -207,7 +206,7 @@ class Base(gast.NodeTransformer): def __init__(self): self.foo_property = None - class DummyTransformer(Base): + class DummyTransformer(NodeStateTracker, ast.NodeTransformer): def visit_If(self, node): self.state[FooType].enter() @@ -264,12 +263,6 @@ class Base(gast.NodeTransformer): print(loader.load_ast(node)) return node - def create_assignment(self, target, expression): - template = """ - target = expression - """ - return templates.replace(template, target=target, expression=expression) - def visit_block(self, nodes, before_visit=None, after_visit=None): """A more powerful version of generic_visit for statement blocks. @@ -346,6 +339,32 @@ class Base(gast.NodeTransformer): node_destination = new_destination return results + def _get_source(self, node): + try: + source, _ = loader.load_ast(node) + return source + # pylint: disable=broad-except + # This function is used for error reporting. If an exception occurs here, + # it should be suppressed, in favor of emitting as informative a message + # about the original error as possible. + except Exception: + return '' + + +# TODO(mdan): Rename to PythonCodeTransformer. +class Base(NodeStateTracker, gast.NodeTransformer): + """Base class for general-purpose Python-to-Python code transformation. + + This is an extension of ast.NodeTransformer that provides the additional + functions offered by NodeStateTracker. + """ + + def create_assignment(self, target, expression): + template = """ + target = expression + """ + return templates.replace(template, target=target, expression=expression) + # TODO(mdan): Remove. def apply_to_single_assignments(self, targets, values, apply_fn): """Applies a function to each individual assignment. @@ -394,17 +413,6 @@ class Base(gast.NodeTransformer): # TODO(mdan): Look into allowing to rewrite the AST here. apply_fn(target, values) - def _get_source(self, node): - try: - source, _ = loader.load_ast(node) - return source - # pylint: disable=broad-except - # This function is used for error reporting. If an exception occurs here, - # it should be suppressed, in favor of emitting as informative a message - # about the original error as possible. - except Exception: - return '' - def visit(self, node): if not isinstance(node, gast.AST): # This is not that uncommon a mistake: various node bodies are lists, for @@ -460,3 +468,69 @@ class Base(gast.NodeTransformer): self.ctx.current_origin = parent_origin return result + + +class CodeGenerator(NodeStateTracker, gast.NodeVisitor): + """Base class for general-purpose Python-to-string code transformation. + + Similar to Base, but outputs arbitrary strings instead of a Python AST. + + This uses the same visitor mechanism that the standard NodeVisitor uses, + meaning that subclasses write handlers for the different kinds of nodes. + New code is generated using the emit method, which appends to a code buffer + that can be afterwards obtained from code_buffer. + + Example: + + class SimpleCodeGen(CodeGenerator): + + def visitIf(self, node): + self.emit('if ') + self.visit(node.test) + self.emit(' { ') + self.visit(node.body) + self.emit(' } else { ') + self.visit(node.orelse) + self.emit(' } ') + + node = ast.parse(...) + gen = SimpleCodeGen() + gen.visit(node) + # gen.code_buffer contains the resulting code + """ + + def __init__(self, ctx): + super(CodeGenerator, self).__init__(ctx) + + self._output_code = '' + self.source_map = {} + + def emit(self, code): + self._output_code += code + + @property + def code_buffer(self): + return self._output_code + + def visit(self, node): + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + return + + parent_origin = self.ctx.current_origin + eof_before = len(self._output_code) + if anno.hasanno(node, anno.Basic.ORIGIN): + self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN) + + try: + super(CodeGenerator, self).visit(node) + + # By default, all replacements receive the origin info of the replaced + # node. + eof_after = len(self._output_code) + if eof_before - eof_after: + inherited_origin = anno.getanno( + node, anno.Basic.ORIGIN, default=parent_origin) + if inherited_origin is not None: + self.source_map[(eof_before, eof_after)] = inherited_origin + finally: + self.ctx.current_origin = parent_origin diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py index 05bae8e8f31..4408395f813 100644 --- a/tensorflow/python/autograph/pyct/transformer_test.py +++ b/tensorflow/python/autograph/pyct/transformer_test.py @@ -295,5 +295,66 @@ class TransformerTest(test.TestCase): anno.getanno(aug_assign_node, anno.Basic.ORIGIN).loc.lineno, 104) +class CodeGeneratorTest(test.TestCase): + + def _simple_context(self): + entity_info = transformer.EntityInfo( + source_code=None, source_file=None, future_features=(), namespace=None) + return transformer.Context(entity_info) + + def test_basic_codegen(self): + + class TestCodegen(transformer.CodeGenerator): + + def visit_Assign(self, node): + self.emit(parser.unparse(node, include_encoding_marker=False)) + self.emit('\n') + + def visit_Return(self, node): + self.emit(parser.unparse(node, include_encoding_marker=False)) + self.emit('\n') + + def visit_If(self, node): + self.emit('if ') + # This is just for simplifity. A real generator will walk the tree and + # emit proper code. + self.emit(parser.unparse(node.test, include_encoding_marker=False)) + self.emit(' {\n') + self.visit_block(node.body) + self.emit('} else {\n') + self.visit_block(node.orelse) + self.emit('}\n') + + tg = TestCodegen(self._simple_context()) + + def test_fn(): + x = 1 + if x > 0: + x = 2 + if x > 1: + x = 3 + return x + + node, source = parser.parse_entity(test_fn, future_features=()) + origin_info.resolve(node, source, 'test_file', 100, 0) + tg.visit(node) + + self.assertEqual( + tg.code_buffer, '\n'.join([ + 'x = 1', + 'if (x > 0) {', + 'x = 2', + 'if (x > 1) {', + 'x = 3', + '} else {', + '}', + '} else {', + '}', + 'return x', + '', + ])) + # TODO(mdan): Test the source map. + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5d7ee54a469..d356e72bcba 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 3, 31) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py index d23bbbe615a..0dd7ae1f083 100644 --- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py @@ -48,7 +48,7 @@ def _format_record(array, sparse): return { "values": array, "indices": [[i] for i in range(len(array))], - "dense_shape": [len(array),] + "dense_shape": (len(array),) } return array @@ -402,16 +402,13 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase, bucket_size = 10 def _build_dataset(): - input_data = [list(range(i + 1)) for i in range(min_len, max_len)] - + input_data = [range(i+1) for i in range(min_len, max_len)] def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) - dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) - dataset = dataset.map(_to_sparse_tensor) return dataset diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py index 60b9dcaa4c5..e66f401ed8e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py @@ -23,10 +23,13 @@ import numpy as np from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest from tensorflow.python.framework import combinations from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -42,11 +45,27 @@ def _make_vector_ds(nrows): return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x)) -def _make_matrix_ds(nrows): +def _make_matrix_ds1(nrows): """Create a test dataset with matrix elements (of varying size).""" return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x)) +def _make_matrix_ds2(nrows): + """Create a test dataset with matrix elements (of varying size).""" + return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x)) + + +def _make_matrix_ds_fully_defined(nrows): + """Create a test dataset with matrix elements (of varying size).""" + return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x)) + + +def _make_5dtensor_ds(nrows): + """Create a test dataset with matrix elements (of varying size).""" + return _make_scalar_ds(nrows).map( + lambda x: array_ops.fill([2, x, 3, 2*x, 4], x)) + + def _make_ragged_ds(nrows): """Create a test dataset with RaggedTensor elements (of varying size).""" values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)] @@ -54,6 +73,28 @@ def _make_ragged_ds(nrows): return dataset_ops.Dataset.from_tensor_slices(rt) +def _make_dict_ds(nrows): + """Create a test set with various element shapes.""" + def transform(x): + return { + 'shape=[]': ops.convert_to_tensor(x), + 'shape=[x]': math_ops.range(x), + 'shape=[x, 2]': array_ops.fill([x, 2], x), + 'shape=[2, x]': array_ops.fill([2, x], x), + 'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x) + } + return _make_scalar_ds(nrows).map(transform) + + +def _make_tuple_ds(nrows): + """Create a test set with various element shapes.""" + def transform(x): + return (ops.convert_to_tensor(x), + math_ops.range(x), + array_ops.fill([x, 2], x)) + return _make_scalar_ds(nrows).map(transform) + + def _to_list(v): return v.to_list() if hasattr(v, 'to_list') else v.tolist() @@ -65,8 +106,9 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): test_base.default_test_combinations(), combinations.combine( make_dataset=[ - _make_scalar_ds, _make_vector_ds, _make_matrix_ds, - _make_ragged_ds + _make_scalar_ds, _make_vector_ds, _make_matrix_ds1, + _make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds, + _make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined, ], nrows=[0, 20, 23], batch_size=[4], @@ -77,7 +119,8 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): # Get the unbatched rows (so we can check expected values). get_next = self.getNext(dataset) - rows = [_to_list(self.evaluate(get_next())) for _ in range(nrows)] + rows = [nest.map_structure(_to_list, self.evaluate(get_next())) + for _ in range(nrows)] # Batch the dataset, and check that batches match slices from `rows`. batched_dataset = dataset.apply( @@ -90,7 +133,11 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): end_row = min(end_row, nrows) result = self.evaluate(get_next()) - self.assertAllEqual(result, rows[start_row:end_row]) + # Use nest for potentially nested datasets. + nest.map_structure_up_to( + result, lambda a, *b: self.assertAllEqual(a, list(b)), + result, *rows[start_row:end_row]) + with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) diff --git a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py index df59ebb996e..5be6cca9332 100644 --- a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py @@ -262,10 +262,9 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase): with ops.device(self._device2): self.assertDatasetProduces(dataset2, range(0, 200, 2)) - # TODO(b/150821179): Re-enable this test. @combinations.generate( combinations.combine(tf_api_version=[2], mode=["eager"])) - def _testVariableInput(self): + def testVariableInput(self): with ops.device(self._device0): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index d976977e305..398ec98a7cb 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -50,11 +50,23 @@ def dense_to_ragged_batch(batch_size, batch from being produced. Unlike `tf.data.Dataset.batch`, the input elements to be batched may have - different shapes, and each batch will be encoded as a `tf.RaggedTensor`. + different shapes: + + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is + fully defined, then it is batched as normal. + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains + one or more axes with unknown size (i.e., `shape[i]=None`), then the output + will contain a `tf.RaggedTensor` that is ragged up to any of such + dimensions. + * If an input element is a `tf.RaggedTensor` or any other type, then it is + batched as normal. + Example: >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6)) >>> dataset = dataset.map(lambda x: tf.range(x)) + >>> dataset.element_spec.shape + TensorShape([None]) >>> dataset = dataset.apply( ... tf.data.experimental.dense_to_ragged_batch(batch_size=2)) >>> for batch in dataset: @@ -385,35 +397,52 @@ class _DenseToRaggedDataset(dataset_ops.UnaryDataset): any new ragged tensors. Existing `tf.RaggedTensor` elements do *not* have their row_splits dtype changed. """ - # Replace each TensorSpec in the input dataset's structure with a # corresponding RaggedTensorSpec. def to_ragged_spec(spec): - if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0: + """Returns the new spec based on RaggedTensors.""" + if (not isinstance(spec, tensor_spec.TensorSpec) or + spec.shape.rank is None or + spec.shape.is_fully_defined()): + return spec + else: + ragged_rank = max([ + axis for (axis, size) in enumerate(spec.shape.as_list()) + if size is None + ]) return ragged_tensor.RaggedTensorSpec( shape=spec.shape, dtype=spec.dtype, - ragged_rank=0, + ragged_rank=ragged_rank, row_splits_dtype=row_splits_dtype) - else: - return spec self._structure = nest.map_structure(to_ragged_spec, input_dataset.element_spec) # Replace each tf.Tensor value in the input dataset with a variant-encoded - # RaggedTensor. Since we're updating the corresponding structure to be + # RaggedTensor. Since we're updating the corresponding structure to be # a RaggedTensorSpec, this variant-encoded tensor will be decoded with # RaggedTensorSpec._from_tensor_list. def to_ragged_variant(value): - if isinstance(value, ops.Tensor) and value.shape.ndims != 0: - spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) - return spec._to_tensor_list(value)[0] # pylint: disable=protected-access - else: + """Re-encode Tensors as RaggedTensors.""" + if (not isinstance(value, ops.Tensor) or + value.shape.rank is None or + value.shape.is_fully_defined()): return value + else: + spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) + if spec._ragged_rank > 0: # pylint: disable=protected-access + value = ragged_tensor.RaggedTensor.from_tensor( + value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access + return spec._to_tensor_list(value)[0] # pylint: disable=protected-access - self._mapped_dataset = input_dataset.map( - lambda value: nest.map_structure(to_ragged_variant, value)) + # Tuples are automatically unpacked by `dataset.map` so we repack them. + if dataset_ops._should_unpack_args(input_dataset.element_spec): # pylint: disable=protected-access + map_fn = lambda *value: nest.map_structure(to_ragged_variant, value) + else: + map_fn = lambda value: nest.map_structure(to_ragged_variant, value) + + self._mapped_dataset = input_dataset.map(map_fn) variant = self._mapped_dataset._variant_tensor # pylint: disable=protected-access super(_DenseToRaggedDataset, self).__init__(input_dataset, variant) diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index 288d0e694f2..ecc022b58a5 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -28,12 +28,7 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import script_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -231,6 +226,23 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[b"foo", b"bar", b"baz"]) + @combinations.generate(test_base.default_test_combinations()) + def testFromGeneratorDict(self): + def generator(): + yield {"a": "foo", "b": [1, 2]} + yield {"a": "bar", "b": [3, 4]} + yield {"a": "baz", "b": [5, 6]} + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_types={"a": dtypes.string, "b": dtypes.int32}, + output_shapes={"a": [], "b": [None]}) + self.assertDatasetProduces( + dataset, + expected_output=[{"a": b"foo", "b": [1, 2]}, + {"a": b"bar", "b": [3, 4]}, + {"a": b"baz", "b": [5, 6]}]) + @combinations.generate(test_base.default_test_combinations()) def testFromGeneratorTypeError(self): def generator(): @@ -246,7 +258,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError("The expected type was int64"): self.evaluate(get_next()) self.assertAllEqual([7, 8, 9], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -266,7 +278,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 2, 3], self.evaluate(get_next())) self.assertAllEqual([4, 5, 6], self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): self.evaluate(get_next()) self.assertAllEqual([11, 12, 13], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -287,9 +299,11 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual((1, 2), self.evaluate(get_next())) self.assertEqual((3, 4), self.evaluate(get_next())) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): self.evaluate(get_next()) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): self.evaluate(get_next()) self.assertEqual((9, 10), self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -408,12 +422,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): stateful=True) dummy = constant_op.constant(37) - - dataset = dataset_ops._GeneratorDataset( - dummy, lambda x: x, lambda x: x, finalize_fn, - tensor_spec.TensorSpec((), dtypes.int32)) - - dataset = dataset.take(2) + dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x, + finalize_fn).take(2) get_next = self.getNext(dataset) self.assertAllEqual(37, self.evaluate(get_next())) @@ -435,46 +445,6 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([20], self.evaluate(get_next())) - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorRaggedTensor(self): - - def generator(): - yield ragged_factory_ops.constant([[1, 2], [3]], - dtype=dtypes.int64, - ragged_rank=1) - - dataset = dataset_ops.Dataset.from_generator( - generator, - output_signature=ragged_tensor.RaggedTensorSpec( - shape=(2, None), dtype=dtypes.int64)) - get_next = self.getNext(dataset) - - ret = get_next() - - self.assertIsInstance(ret, ragged_tensor.RaggedTensor) - self.assertAllEqual([1, 2, 3], ret.values) - - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorSparseTensor(self): - - def generator(): - yield sparse_tensor.SparseTensor( - indices=[[0, 0], [1, 2]], - values=constant_op.constant([1, 2], dtype=dtypes.int64), - dense_shape=[3, 4]) - - dataset = dataset_ops.Dataset.from_generator( - generator, - output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64)) - - get_next = self.getNext(dataset) - - ret = get_next() - - self.assertIsInstance(ret, sparse_tensor.SparseTensor) - self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]], - sparse_ops.sparse_tensor_to_dense(ret)) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 94b50a7864d..36689ed75fb 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -946,9 +946,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @def_function.function def fn(): - output_spec = tensor_spec.TensorSpec((), dtypes.int64) - dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, - output_spec) + dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn) iterator = iter(dataset) next(iterator) diff --git a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py index 4917d6ec163..8ba9d4c925a 100644 --- a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py +++ b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py @@ -116,7 +116,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): gc.collect() tensors = [ - o for o in gc.get_objects() if isinstance(o, tensor_like._TensorLike) + o for o in gc.get_objects() if isinstance(o, tensor_like.TensorLike) ] self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors)) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 359b55ca78c..d2c247678a2 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -408,7 +408,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def element_spec(self): """The type specification of an element of this dataset. - >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec + >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + >>> dataset.element_spec TensorSpec(shape=(), dtype=tf.int32, name=None) Returns: @@ -674,48 +675,27 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): del self._iterators[iterator_id] @staticmethod - @deprecation.deprecated_args(None, "Use output_signature instead", - "output_types", "output_shapes") - def from_generator(generator, - output_types=None, - output_shapes=None, - args=None, - output_signature=None): + def from_generator(generator, output_types, output_shapes=None, args=None): """Creates a `Dataset` whose elements are generated by `generator`. The `generator` argument must be a callable object that returns an object that supports the `iter()` protocol (e.g. a generator function). + The elements generated by `generator` must be compatible with the given + `output_types` and (optional) `output_shapes` arguments. - The elements generated by `generator` must be compatible with either the - given `output_signature` argument or with the given `output_types` and - (optionally) `output_shapes` arguments whichiver was specified. - - The recommended way to call `from_generator` is to use the - `output_signature` argument. In this case the output will be assumed to - consist of objects with the classes, shapes and types defined by - `tf.TypeSpec` objects from `output_signature` argument: - + >>> import itertools + >>> >>> def gen(): - ... ragged_tensor = tf.ragged.constant([[1, 2], [3]], - ... ragged_rank=1, - ... dtype=tf.int64) - ... yield 42, ragged_tensor + ... for i in itertools.count(1): + ... yield (i, [1] * i) >>> >>> dataset = tf.data.Dataset.from_generator( ... gen, - ... output_signature=( - ... tf.TensorSpec(shape=(), dtype=tf.int64), - ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64))) + ... (tf.int64, tf.int64), + ... (tf.TensorShape([]), tf.TensorShape([None]))) >>> - >>> list(dataset.take(1)) - [(, - )] - - There is also a deprecated way to call `from_generator` by either with - `output_types` argument alone or together with `output_shapes` argument. - In this case the output of the function will be assumed to consist of - `tf.Tensor` objects with with the types defined by `output_types` and with - the shapes which are either unknown or defined by `output_shapes`. + >>> list(dataset.take(3).as_numpy_iterator()) + [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))] Note: The current implementation of `Dataset.from_generator()` uses `tf.numpy_function` and inherits the same constraints. In particular, it @@ -739,56 +719,31 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): `iter()` protocol. If `args` is not specified, `generator` must take no arguments; otherwise it must take as many arguments as there are values in `args`. - output_types: (Optional.) A nested structure of `tf.DType` objects - corresponding to each component of an element yielded by `generator`. + output_types: A nested structure of `tf.DType` objects corresponding to + each component of an element yielded by `generator`. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element yielded by `generator`. args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and passed to `generator` as NumPy-array arguments. - output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects - corresponding to each component of an element yielded by `generator`. Returns: Dataset: A `Dataset`. """ if not callable(generator): raise TypeError("`generator` must be callable.") - - if output_signature is not None: - if output_types is not None: - raise TypeError("`output_types` can not be used together with " - "`output_signature`") - if output_shapes is not None: - raise TypeError("`output_shapes` can not be used together with " - "`output_signature`") - if not all( - isinstance(_, type_spec.TypeSpec) - for _ in nest.flatten(output_signature)): - raise TypeError("All the elements of `output_siganture` must be " - "a `tf.TypeSpec` objects.") + if output_shapes is None: + output_shapes = nest.map_structure( + lambda _: tensor_shape.TensorShape(None), output_types) else: - if output_types is None and output_shapes is not None: - raise TypeError("`output_shapes` can not be used alone without " - "`output_types`") - - if output_signature is None: - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to(output_types, - tensor_shape.as_shape, - output_shapes) - output_signature = nest.map_structure_up_to(output_types, - tensor_spec.TensorSpec, - output_shapes, output_types) - + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) if args is None: args = () else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flat_output_types = structure.get_flat_tensor_types(output_signature) + flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)] + flattened_shapes = nest.flatten(output_shapes) generator_state = DatasetV2._GeneratorState(generator) @@ -826,41 +781,56 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): """A `py_func` that will be called to invoke the iterator.""" # `next()` raises `StopIteration` when there are no more # elements remaining to be generated. - values = next(generator_state.get_iterator(iterator_id.numpy())) - - def serialize_structure(s): - return nest.map_structure(lambda ts: ts._serialize(), s) # pylint: disable=protected-access + values = next(generator_state.get_iterator(iterator_id)) + # Use the same _convert function from the py_func() implementation to + # convert the returned values to arrays early, so that we can inspect + # their values. try: - output_dtypes = nest.map_structure(lambda t: t.dtype, - output_signature) - values = structure.normalize_element(values, dtypes=output_dtypes) + flattened_values = nest.flatten_up_to(output_types, values) except (TypeError, ValueError): - six.reraise( - TypeError, - TypeError( - "`generator` yielded an element that did not match the " - "expected structure. The expected structure was %s, but the " - "yielded element was %s." % - (serialize_structure(output_signature), values)), - sys.exc_info()[2]) + six.reraise(TypeError, TypeError( + "`generator` yielded an element that did not match the expected " + "structure. The expected structure was %s, but the yielded " + "element was %s." % (output_types, values)), sys.exc_info()[2]) + ret_arrays = [] + for ret, dtype in zip(flattened_values, flattened_types): + try: + ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access + ret, dtype=dtype.as_numpy_dtype)) + except (TypeError, ValueError): + six.reraise(TypeError, TypeError( + "`generator` yielded an element that could not be converted to " + "the expected type. The expected type was %s, but the yielded " + "element was %s." % (dtype.name, ret)), sys.exc_info()[2]) - values_spec = structure.type_spec_from_value(values) + # Additional type and shape checking to ensure that the components + # of the generated element match the `output_types` and `output_shapes` + # arguments. + for (ret_array, expected_dtype, expected_shape) in zip( + ret_arrays, flattened_types, flattened_shapes): + if ret_array.dtype != expected_dtype.as_numpy_dtype: + raise TypeError( + "`generator` yielded an element of type %s where an element " + "of type %s was expected." % (ret_array.dtype, + expected_dtype.as_numpy_dtype)) + if not expected_shape.is_compatible_with(ret_array.shape): + raise ValueError( + "`generator` yielded an element of shape %s where an element " + "of shape %s was expected." % (ret_array.shape, expected_shape)) - if not structure.are_compatible(values_spec, output_signature): - raise TypeError( - "`generator` yielded an element of TypeSpec%s where an element " - "of TypeSpec%s was expected." % - (serialize_structure(values_spec), - serialize_structure(output_signature))) + return ret_arrays - return structure.to_tensor_list(output_signature, values) + flat_values = script_ops.numpy_function(generator_py_func, + [iterator_id_t], flattened_types) - return script_ops._eager_py_func( # pylint: disable=protected-access - generator_py_func, - inp=[iterator_id_t], - Tout=flat_output_types, - use_tape_cache=False) + # The `py_func()` op drops the inferred shapes, so we add them back in + # here. + if output_shapes is not None: + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" @@ -886,7 +856,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): # given ID, and raises StopIteration when that iterator contains no # more elements. return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn, - finalize_fn, output_signature) + finalize_fn) # A single-element dataset that, each time it is evaluated, contains a # freshly-generated and unique (for the returned dataset) int64 @@ -2308,14 +2278,9 @@ class DatasetV1(DatasetV2): @staticmethod @functools.wraps(DatasetV2.from_generator) - def from_generator(generator, - output_types=None, - output_shapes=None, - args=None, - output_signature=None): - return DatasetV1Adapter( - DatasetV2.from_generator(generator, output_types, output_shapes, args, - output_signature)) + def from_generator(generator, output_types, output_shapes=None, args=None): + return DatasetV1Adapter(DatasetV2.from_generator( + generator, output_types, output_shapes, args)) @staticmethod @functools.wraps(DatasetV2.range) @@ -3296,8 +3261,7 @@ class StructuredFunctionWrapper(object): class _GeneratorDataset(DatasetSource): """A `Dataset` that generates elements by invoking a function.""" - def __init__(self, init_args, init_func, next_func, finalize_func, - output_signature): + def __init__(self, init_args, init_func, next_func, finalize_func): """Constructs a `_GeneratorDataset`. Args: @@ -3311,8 +3275,6 @@ class _GeneratorDataset(DatasetSource): finalize_func: A TensorFlow function that will be called on the result of `init_func` immediately before a C++ iterator over this dataset is destroyed. The return value is ignored. - output_signature: A nested structure of `tf.TypeSpec` objects describing - the output of `next_func`. """ self._init_args = init_args @@ -3332,9 +3294,6 @@ class _GeneratorDataset(DatasetSource): finalize_func, self._transformation_name(), input_structure=self._init_func.output_structure) - - self._output_signature = output_signature - variant_tensor = gen_dataset_ops.generator_dataset( structure.to_tensor_list(self._init_structure, self._init_args) + self._init_func.function.captured_inputs, @@ -3348,7 +3307,7 @@ class _GeneratorDataset(DatasetSource): @property def element_spec(self): - return self._output_signature + return self._next_func.output_structure def _transformation_name(self): return "Dataset.from_generator()" diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index ee6151742f6..87825005069 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank): # TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once # it is a subclass of `CompositeTensor`. -def normalize_element(element, dtypes=None): +def normalize_element(element): """Normalizes a nested structure of element components. * Components matching `SparseTensorSpec` are converted to `SparseTensor`. @@ -78,10 +78,6 @@ def normalize_element(element, dtypes=None): Args: element: A nested structure of individual components. - dtypes: (Optional.) A nested structure of `tf.DType` objects corresponding - to each component of `element`. If specified, it will be used to set the - exact type of output tensor when converting input components which - are not tensors themselves (e.g. numpy arrays, native python types, etc.) Returns: A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, @@ -89,21 +85,17 @@ def normalize_element(element, dtypes=None): """ components = nest.flatten(element) normalized_components = [] - if dtypes is None: - flattened_dtypes = [None] * len(components) - else: - flattened_dtypes = nest.flatten(dtypes) with ops.name_scope("normalize_element"): # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top - for i, (t, dtype) in enumerate(zip(components, flattened_dtypes)): + for i, t in enumerate(components): try: spec = type_spec_from_value(t, use_fallback=False) except TypeError: # TypeError indicates it was not possible to compute a `TypeSpec` for # the value. As a fallback try converting the value to a tensor. normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) + ops.convert_to_tensor(t, name="component_%d" % i)) else: if isinstance(spec, sparse_tensor.SparseTensorSpec): normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) @@ -120,7 +112,7 @@ def normalize_element(element, dtypes=None): normalized_components.append(t) else: normalized_components.append( - ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) + ops.convert_to_tensor(t, name="component_%d" % i)) return nest.pack_sequence_as(element, normalized_components) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index d34e34d2307..a6178fce24b 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -929,6 +929,7 @@ py_test( srcs = ["wrappers/framework_test.py"], python_version = "PY3", srcs_version = "PY2AND3", + tags = ["no_rocm"], deps = [ ":debug_data", ":framework", @@ -1149,6 +1150,7 @@ py_test( srcs = ["cli/debugger_cli_common_test.py"], python_version = "PY3", srcs_version = "PY2AND3", + tags = ["no_rocm"], deps = [ ":debugger_cli_common", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 46f66f42b0a..b0f982e3d5c 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -176,6 +176,7 @@ py_test( srcs = ["distribute_lib_test.py"], python_version = "PY3", srcs_version = "PY2AND3", + tags = ["no_rocm"], deps = [ ":combinations", ":distribute_lib", @@ -267,19 +268,25 @@ py_library( ":shared_variable_creator", ":values", "//tensorflow/python:array_ops", + "//tensorflow/python:config", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", "//tensorflow/python:pywrap_tfe", "//tensorflow/python:summary_ops_v2", "//tensorflow/python:tensor_util", + "//tensorflow/python:tf_export", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/autograph/core", + "//tensorflow/python/autograph/impl", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:tape", ], ) @@ -1214,10 +1221,12 @@ cuda_py_test( ], deps = [ ":combinations", + ":distribute_lib", ":mirrored_strategy", ":multi_worker_test_base", ":strategy_combinations", ":strategy_test_lib", + ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -1227,8 +1236,6 @@ cuda_py_test( "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", "//tensorflow/python/autograph/core:test_lib", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/keras/layers", @@ -1351,6 +1358,7 @@ distribute_py_test( shard_count = 5, tags = [ "multi_and_single_gpu", + "no_rocm", ], deps = [ ":saved_model_test_base", @@ -1383,6 +1391,7 @@ distribute_py_test( shard_count = 5, tags = [ "multi_and_single_gpu", + "no_rocm", ], deps = [ ":saved_model_test_base", diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 31c1c6665fa..698bf2c2ce6 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -1057,26 +1057,48 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): t.captured_var_scope = variable_scope.get_variable_scope() t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access - # NOTE(priyag): Throw an error if there is a merge call in the middle of a - # `fn` passed to call_for_each_replica which changes the graph being used - # while calling `fn`. This can happen when the `fn` is decorated with - # `tf.function` and there is a merge_call in `fn`. This breaks because each - # thread tries to create a distinct tf.function. Each tf.function creation - # takes a lock, and so if there is a merge call in the middle, the lock is - # never released and subsequent replica threads cannot proceed to define - # their own functions. Checking for the graph being the same is one way for - # us to check this didn't happen. + # It is problematic if `merge_call` is called under a different graph other + # than the one that `_call_for_each_replica` is called under, there are + # 3 cases this can happen: + # + # 1. The `fn` passed to `_call_for_each_replica` is decorated with + # `tf.function` and there is a `merge_call` in `fn`. Since + # MirroredStrategy traces a separate function per thread (per device), + # and each trace takes a shared lock, the lock is never released by the + # first thread and subsequent replica threads cannot proceed to trace + # their own functions. This issue is addressed by always converting + # `_call_for_each_replica(tf.function(f))` to + # ``tf.function(_call_for_each_replica(f))`.` in + # `MirroredStrategy._call_for_each_replica`. + # + # 2. The `fn` passed to `_call_for_each_replica` contains a nested + # `tf.function`, and there is a `merge_call` in the nested `tf.function`. + # In this case each thread can successfully trace its own function, but + # since the `merge_fn` passed to `merge_call` is executed in the main + # thread (where `_call_for_each_replica` is executed), it can't access + # the tensors that come from different graphs. + # + # 3. The `fn` passed to `_call_for_each_replica` contains a control-flow + # statement, and there is a `merge_call` inside the control-flow body, + # `fn` or `_call_for_each_replica` is decorated with `tf.function`. + # Control flow statement creates a separate graph for its body, similar + # to #2, `merge_fn` executed in the main thread can't access the + # tensors that come from different graphs. + # + # We raise an error for #2 and #3. if ops.get_default_graph() != t.graph: raise RuntimeError( - "`merge_call` called while defining a new graph or a tf.function. " - "This can often happen if the function `fn` passed to " - "`strategy.experimental_run()` is decorated with " - "`@tf.function` (or contains a nested `@tf.function`), and `fn` " - "contains a synchronization point, such as aggregating gradients. " - "This behavior is not yet supported. Instead, please wrap the entire " - "call `strategy.experimental_run(fn)` in a `@tf.function`, and avoid " - "nested `tf.function`s that may potentially cross a synchronization " - "boundary.") + "`merge_call` called while defining a new graph or a tf.function." + " This can often happen if the function `fn` passed to" + " `strategy.run()` contains a nested `@tf.function`, and the nested " + "`@tf.function` contains a synchronization point, such as aggregating" + " gradients (e.g, optimizer.apply_gradients), or if the function `fn`" + " uses a control flow statement which contains a synchronization" + " point in the body. Such behaviors are not yet supported. Instead," + " please avoid nested `tf.function`s or control flow statements that" + " may potentially cross a synchronization boundary, for example," + " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`" + " inside a `tf.function` or move the control flow out of `fn`") t.has_paused.set() t.should_run.wait() diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index f06360ca021..2eb2191ad48 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -379,15 +379,18 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase): self.evaluate(distribution.experimental_local_results(result))) self.assertLen(traces, distribution.num_replicas_in_sync) - def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution): - def merge_fn(_): - pass + def testControlFlowFunctionInCallForEachReplicaWithMergeCall( + self, distribution): + + def merge_fn(strategy, value): + return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None) @def_function.function def model_fn(): + def body_fn(i): - ds_context.get_replica_context().merge_call(merge_fn) - return i + 1 + return ds_context.get_replica_context().merge_call(merge_fn, args=(i,)) + return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0]) with distribution.scope(): @@ -395,6 +398,25 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase): RuntimeError, "`merge_call` called while defining a new graph."): distribution.extended.call_for_each_replica(model_fn) + def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution): + + def merge_fn(strategy, value): + return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None) + + def model_fn(): + + @def_function.function + def model_fn_nested(): + t = constant_op.constant(1) + return ds_context.get_replica_context().merge_call(merge_fn, args=(t,)) + + return model_fn_nested() + + with distribution.scope(): + with self.assertRaisesRegexp( + RuntimeError, "`merge_call` called while defining a new graph."): + distribution.extended.call_for_each_replica(model_fn) + def testFunctionInCallForEachReplicaWithMergeCall(self, distribution): def merge_fn(_): pass diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 05c9f75f09e..3c095469927 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -672,11 +672,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): reduce_op, value, destinations, self._num_replicas_in_sync) # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`. - # Always performs the reduction on the TPU host. - with ops.device(self._host_device): - output = math_ops.add_n(value.values) - if reduce_op == reduce_util.ReduceOp.MEAN: - output *= (1. / len(value.values)) + output = math_ops.add_n(value.values) + if reduce_op == reduce_util.ReduceOp.MEAN: + output *= (1. / len(value.values)) devices = cross_device_ops_lib.get_devices_from(destinations) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 5158d1fdb35..de4c975d5ef 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver @@ -30,15 +31,18 @@ from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import flags from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import device_assignment as device_assignment_lib +from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_strategy_util @@ -349,6 +353,149 @@ class TPUStrategyTest(test.TestCase): for i in dataset: strategy.run(step_fn, args=(i,)) + # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py. + def test_all_reduce_on_sync_on_read_variable(self): + strategy = get_tpu_strategy() + dataset = dataset_ops.Dataset.range( + strategy.num_replicas_in_sync, output_type=dtypes.float32).batch( + strategy.num_replicas_in_sync, drop_remainder=True) + input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) + + with strategy.scope(): + w = variables.Variable( + (0.,), + shape=(1,), + trainable=False, + synchronization=variables.VariableSynchronization.ON_READ, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + + @def_function.function + def run(iterator): + + def computation(x): + w.assign(x + w) + return w + + def all_reduce(x): + ctx = distribution_strategy_context.get_replica_context() + return ctx.all_reduce("SUM", w) + x + + outputs = strategy.run(computation, args=(next(iterator),)) + outputs2 = strategy.experimental_local_results( + strategy.run(all_reduce, args=(outputs,))) + return outputs2 + + data = range(0, strategy.num_replicas_in_sync) + data_sum = sum(data) + expected_result = [ + [x + data_sum] for x in range(0, strategy.num_replicas_in_sync) + ] + self.assertAllEqual(expected_result, run(input_iterator)) + self.assertAllEqual((0.,), w.read_value()) + + # TODO(b/140633529): Re-enable the test. + def disable_test_experimental_run_output_on_device(self): + strategy = get_tpu_strategy() + + def computation(x): + return math_ops.square(x) + + @def_function.function + def train_step(): + outputs = strategy.experimental_local_results( + strategy.run(computation, args=(2,))) + return outputs + + results = train_step() + self.assertAllEqual([4., 4.], results) + self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:0", + results[0].backing_device) + self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1", + results[1].backing_device) + + def test_composite_input(self): + strategy = get_tpu_strategy() + if strategy.num_replicas_in_sync != 2: + self.skipTest("Test assumes two replicas.") + + with strategy.scope(): + table = variables.Variable( + initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32) + + @def_function.function + def sparse_lookup(iterator): + + def tpu_function(sparse): + # Assumes dense_shape is (2, *) + looked_up = array_ops.gather(table, sparse.values) + return math_ops.unsorted_segment_sum(looked_up, sparse.indices[:, 0], 2) + + return strategy.experimental_local_results( + strategy.run(tpu_function, args=(next(iterator),))) + + def dataset_fn(_): + dataset = dataset_ops.Dataset.range(2) + + def make_sparse(_): + return sparse_tensor.SparseTensor( + indices=array_ops.constant([[0, 0], [1, 0], [1, 1]], + dtype=dtypes.int64), + values=array_ops.constant([0, 0, 1], dtype=dtypes.int32), + dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64)) + + return dataset.map(make_sparse) + + strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access + dataset = iter( + strategy.experimental_distribute_datasets_from_function(dataset_fn)) + + result = sparse_lookup(dataset) + self.assertAllEqual(result, + [[[0.0, 1.0], [3.0, 8.0]], [[0.0, 1.0], [3.0, 8.0]]]) + + def test_composite_input_dynamic_shapes_outside_compilation(self): + strategy = get_tpu_strategy() + if strategy.num_replicas_in_sync != 2: + self.skipTest("Test assumes two replicas.") + + table = variables.Variable( + initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32) + + @def_function.function + def sparse_lookup(iterator): + + def tpu_function(sparse): + lookup = tpu.outside_compilation( + embedding_ops.safe_embedding_lookup_sparse, table, sparse) + return math_ops.reduce_sum(lookup, axis=0) + + return strategy.experimental_local_results( + strategy.run(tpu_function, args=(next(iterator),))) + + def dataset_fn(_): + dataset = dataset_ops.Dataset.range(2) + + def make_sparse(i): + indices = array_ops.constant([[0, 0], [1, 0], [1, 1]], + dtype=dtypes.int64)[0:2 + i] + values = array_ops.constant([0, 0, 1], dtype=dtypes.int32)[0:2 + i] + shape = [ + array_ops.constant([2], dtype=dtypes.int64), + array_ops.expand_dims(1 + i, axis=0) + ] + dense_shape = array_ops.concat(shape, axis=0) + return sparse_tensor.SparseTensor( + indices=indices, values=values, dense_shape=dense_shape) + + return dataset.map(make_sparse) + + strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access + dataset = iter( + strategy.experimental_distribute_datasets_from_function(dataset_fn)) + + result = sparse_lookup(dataset) + self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 2bc5f61d076..2f8f9b8afe4 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import weakref from tensorflow.python.distribute import device_util @@ -404,8 +403,23 @@ def _assign_sub_on_device(device, variable, tensor): return variable.assign_sub(tensor) -DistributedVarOp = collections.namedtuple( - "DistributedVarOp", ["name", "graph", "traceback", "type"]) +class DistributedVarOp(object): + """A class that looks like `tf.Operation`.""" + + def __init__(self, name, graph, traceback, typ): + self.name = name + self.graph = graph + self.traceback = traceback + self.type = typ + + def __eq__(self, o): + if not isinstance(o, self.__class__): + raise NotImplementedError + return (self.name == o.name and self.graph == o.graph and + self.traceback == o.traceback and self.type == o.type) + + def __hash__(self): + return hash((self.name, self.graph, self.traceback, self.type)) class DistributedVariable(DistributedDelegate, variables_lib.Variable): @@ -879,7 +893,13 @@ class MirroredVariable(DistributedVariable, Mirrored): """Converts a variable to a tensor.""" # Try to avoid assignments to and other mutations of MirroredVariable # state except through a DistributionStrategy.extended.update() call. - assert not as_ref + if as_ref: + # A TF 1.x case where the variable is a boolean variable and used like: + # tf.cond(v, true_fn, false_fn). + raise ValueError( + "You may be using variable created under distribute strategy in TF " + "1.x control flows. Try explicitly converting the variable to Tensor " + "using variable.read_value(), or switch to TF 2.x.") return ops.convert_to_tensor( self._get(), dtype=dtype, name=name, as_ref=as_ref) diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 685dbaf4d40..c3c3e0d5286 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -995,7 +995,6 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, ], mode=["eager"])) def testInitScope(self, distribution): diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 47bf5b35a8b..8bad390ef21 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import context from tensorflow.python.util.tf_export import tf_export @@ -62,4 +63,4 @@ def from_dlpack(dlcapsule): Returns: A Tensorflow eager tensor """ - return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9df6113b95f..9c229974b05 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -216,6 +216,7 @@ cuda_py_test( deps = [ ":profiler", ":test", + "//tensorflow/core/profiler/protobuf:trace_events_proto_py", "//tensorflow/python:constant_op", "//tensorflow/python/profiler:traceme", ], @@ -230,6 +231,7 @@ py_library( "//tensorflow/core/profiler:internal", ], deps = [ + "//tensorflow/python:util", "//tensorflow/python/profiler/internal:_pywrap_profiler", ], ) @@ -860,7 +862,10 @@ cuda_py_test( ":def_function", ":remote", ":test", + "//tensorflow/python:dtypes", + "//tensorflow/python:functional_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_spec", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], @@ -891,7 +896,7 @@ cuda_py_test( srcs = ["remote_cluster_test.py"], grpc_enabled = True, python_version = "PY3", - shard_count = 16, + shard_count = 8, tags = [ "no_oss", # This test launches local server "notsan", # TODO(b/152075365) diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py index ae026c0fbbb..117b05e0956 100644 --- a/tensorflow/python/eager/backprop_util.py +++ b/tensorflow/python/eager/backprop_util.py @@ -30,4 +30,4 @@ def IsTrainable(tensor_or_dtype): dtype = dtypes.as_dtype(dtype) return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, - dtypes.resource, dtypes.variant) + dtypes.resource, dtypes.variant, dtypes.bfloat16) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index c0331d760b9..6dfe419e5c6 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -411,6 +411,7 @@ class Context(object): execution_mode = SYNC self._default_is_async = execution_mode == ASYNC self._lazy_remote_inputs_copy = None + self._use_tfrt = None self._server_def = server_def self._collective_ops_server_def = None self._collective_leader = None @@ -514,6 +515,8 @@ class Context(object): if self._lazy_remote_inputs_copy is not None: pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy( opts, self._lazy_remote_inputs_copy) + if self._use_tfrt is not None: + pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) context_handle = pywrap_tfe.TFE_NewContext(opts) finally: pywrap_tfe.TFE_DeleteContextOptions(opts) @@ -1565,6 +1568,21 @@ class Context(object): "lazy_remote_inputs_copy should be set before being initialized.") self._lazy_remote_inputs_copy = lazy_copy + @property + def use_tfrt(self): + return self._use_tfrt + + @use_tfrt.setter + def use_tfrt(self, tfrt): + """Sets whether to use TFRT.""" + if not isinstance(tfrt, bool): + raise ValueError("Expecting a boolean but got %s" % type(tfrt)) + + if self._use_tfrt != tfrt: + if self._initialized: + raise ValueError("use_tfrt should be set before being initialized.") + self._use_tfrt = tfrt + def enable_run_metadata(self): """Enables tracing of op execution via RunMetadata. diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index e49a19807a8..6f50325015f 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -22,7 +22,11 @@ from __future__ import print_function import functools import threading import weakref +import six +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context from tensorflow.python.eager import function as function_lib @@ -36,8 +40,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import base as trackable - - from tensorflow.python.util import nest from tensorflow.python.util import object_identity from tensorflow.python.util import tf_decorator @@ -376,7 +378,10 @@ class Function(object): def embedding_matmul(a, b): # custom implementation here ``` - + This can either be specified as just the string name of the function or + a NameAttrList corresponding to a list of key-value attributes + with the function name. The name of the function will be in the 'name' + field of the NameAttrList. experimental_autograph_options: optional tuple of tensorflow.autograph.Feature values. Allows enabling additional conversion options when autograph is set to True. @@ -445,11 +450,35 @@ class Function(object): self._python_function, wrapped_fn)) + def _create_implements_attribute(self): + """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME.""" + attributes = {} + if isinstance(self._implements, str): + # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a + # NameAttrList. This is used when apart from the function name being + # implemented, a list of attributes is also being specified. + # The attributes are specified as key-value pairs in the NameAttrList + # of the corresponding AttrValue. The function name will be in the + # 'name' field of the NameAttrList. Else, it is just a string + # corresponding to the function name. + try: + implements_attr = six.ensure_text(self._implements, "utf-8") + attr_value = attr_value_pb2.AttrValue() + nameattrlist = attr_value_pb2.NameAttrList() + _text_format.Merge(implements_attr, nameattrlist) + attr_value.func.CopyFrom(nameattrlist) + attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value + except (_text_format.ParseError, DecodeError): + attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements + return attributes + def _defun(self, fn): """Returns a defun generated from the input function.""" attributes = {} + if self._implements is not None: - attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements + attributes = self._create_implements_attribute() + if self._experimental_compile is not None: attributes.update(_XlaMustCompile=bool(self._experimental_compile)) if self._experimental_compile: @@ -1186,6 +1215,10 @@ def function(func=None, `embedded_matmul` (perhaps more efficiently!) by specifying it using this parameter: `@tf.function(experimental_implements="embedded_matmul")` + This can either be specified as just the string name of the function or + a NameAttrList corresponding to a list of key-value attributes associated + with the function name. The name of the function will be in the 'name' + field of the NameAttrList. experimental_autograph_options: Optional tuple of `tf.autograph.experimental.Feature` values. experimental_relax_shapes: When True, `tf.function` may generate fewer, diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 16d57ef36da..adff0858488 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -232,6 +232,24 @@ class DefFunctionTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'): c.f1(inputs) + def testMustBeConstantPropagation(self): + if test.is_built_with_rocm(): + return + + @def_function.function(experimental_compile=True) + def f(): + return constant_op.constant([0, 2, 1], dtype=dtypes.int32) + + @def_function.function(experimental_compile=True) + def g(a, b): + return array_ops.transpose(a, b) + + @def_function.function + def z(): + return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f()) + + z() + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 7b599a995e2..d98c83665b9 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -287,6 +287,36 @@ class FunctionTest(test.TestCase, parameterized.TestCase): numpy.testing.assert_equal(r1.eval(), [3.]) numpy.testing.assert_equal(r2.eval(), [3., 3.]) + def testImplementsAttributeAsNameAttrList(self): + implements_attr = ( + 'name: "embedding_matmul" attr { key: "key1" value { i: 2 } ' + '} attr { key: "key2" value { b: false } }') + v = def_function.function( + experimental_implements=implements_attr)(lambda x, y: x + y) + with context.graph_mode(), self.cached_session(): + a = array_ops.placeholder(dtypes.float32, ()) + b = array_ops.placeholder(dtypes.float32, ()) + v(a, b) + gradients_impl.gradients(v(a, b), [a, b]) + fdefs = ops.get_default_graph().as_graph_def().library.function + self.assertLen(fdefs, 3) + not_present = 0 + present = 0 + for f in fdefs: + name = f.signature.name + if 'forward' in name or 'backward' in name: + not_present += 1 + self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) + else: + present += 1 + attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME] + self.assertIsNotNone(attr_value.func, f) + self.assertEqual(attr_value.func.name, 'embedding_matmul') + name_attrs = attr_value.func.attr + self.assertLen(name_attrs, 2) + self.assertEqual(not_present, 2, fdefs) + self.assertEqual(present, 1, fdefs) + def testExternalControlDependency(self): with ops.Graph().as_default(), self.test_session(): v = variables.Variable(1.0) diff --git a/tensorflow/python/eager/profiler.py b/tensorflow/python/eager/profiler.py index 835a0d72bbf..91761986107 100644 --- a/tensorflow/python/eager/profiler.py +++ b/tensorflow/python/eager/profiler.py @@ -45,6 +45,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler.internal import _pywrap_profiler from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated _profiler = None _profiler_lock = threading.Lock() @@ -62,6 +63,7 @@ class ProfilerNotRunningError(Exception): pass +@deprecated('2020-07-01', 'use `tf.profiler.experimental.start` instead.') def start(): """Start profiling. @@ -84,6 +86,7 @@ def start(): raise ProfilerAlreadyRunningError('Another profiler is running.') +@deprecated('2020-07-01', 'use `tf.profiler.experimental.stop` instead.') def stop(): """Stop current profiling session and return its result. @@ -108,6 +111,10 @@ def stop(): return result +@deprecated( + '2020-07-01', + '`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.' +) def maybe_create_event_file(logdir): """Create an empty event file if not already exists. @@ -126,6 +133,10 @@ def maybe_create_event_file(logdir): event_writer.InitWithSuffix(compat.as_bytes(_EVENT_FILE_SUFFIX)) +@deprecated( + '2020-07-01', + '`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.' +) def save(logdir, result): """Save profile result to TensorBoard logdir. @@ -142,6 +153,7 @@ def save(logdir, result): f.write(result) +@deprecated('2020-07-01', 'use `tf.profiler.experimental.server.start`.') def start_profiler_server(port): """Start a profiler grpc server that listens to given port. @@ -160,6 +172,7 @@ def start_profiler_server(port): _pywrap_profiler.start_server(port) +@deprecated('2020-07-01', 'use `tf.profiler.experimental.Profile` instead.') class Profiler(object): """Context-manager eager profiler api. diff --git a/tensorflow/python/eager/profiler_client.py b/tensorflow/python/eager/profiler_client.py index 81fcd9fe498..cf7aee30708 100644 --- a/tensorflow/python/eager/profiler_client.py +++ b/tensorflow/python/eager/profiler_client.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.profiler.internal import _pywrap_profiler +from tensorflow.python.util.deprecation import deprecated +@deprecated('2020-07-01', 'use `tf.profiler.experimental.client.trace`.') def start_tracing(service_addr, logdir, duration_ms, @@ -48,6 +50,7 @@ def start_tracing(service_addr, duration_ms, num_tracing_attempts) +@deprecated('2020-07-01', 'use `tf.profiler.experimental.client.monitor`.') def monitor(service_addr, duration_ms, monitoring_level=1, diff --git a/tensorflow/python/eager/profiler_test.py b/tensorflow/python/eager/profiler_test.py index 8d9c27c83e6..428cfab1c96 100644 --- a/tensorflow/python/eager/profiler_test.py +++ b/tensorflow/python/eager/profiler_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import os -from tensorflow.core.protobuf import trace_events_pb2 +from tensorflow.core.profiler.protobuf import trace_events_pb2 from tensorflow.python.eager import profiler from tensorflow.python.eager import test from tensorflow.python.framework import config diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index afae0b57ee7..882c8097a0f 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -410,7 +410,7 @@ absl::optional> OpGradientUnusedInputIndices( absl::optional> OpGradientUnusedOutputIndices( const tensorflow::string &op_name) { - static std::array a = {{ + static std::array a = {{ {"Abs"}, {"AccumulateNV2"}, {"Acos"}, @@ -821,20 +821,11 @@ absl::optional> OpGradientUnusedOutputIndices( {"TensorArrayRead"}, {"TensorArrayReadV2"}, {"TensorArrayReadV3"}, - {"TensorArrayScatter"}, - {"TensorArrayScatterV2"}, - {"TensorArrayScatterV3"}, {"TensorArraySize"}, {"TensorArraySizeV2"}, {"TensorArraySizeV3"}, - {"TensorArraySplit"}, - {"TensorArraySplitV2"}, - {"TensorArraySplitV3"}, {"TensorArrayV2"}, {"TensorArrayV3"}, - {"TensorArrayWrite"}, - {"TensorArrayWriteV2"}, - {"TensorArrayWriteV3"}, {"TensorListConcat", 1, {0}}, {"TensorListConcatLists"}, {"TensorListConcatV2", 1, {0}}, @@ -842,7 +833,6 @@ absl::optional> OpGradientUnusedOutputIndices( {"TensorListGather"}, {"TensorListGetItem"}, {"TensorListLength"}, - {"TensorListPopBack", 1, {1}}, {"TensorListPushBack"}, {"TensorListPushBackBatch"}, {"TensorListResize"}, diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index a3f9b0bed5c..a120c1ccdd9 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -75,7 +75,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { std::unique_ptr op = ReleaseThreadLocalOp(ctx); if (!op) { - op.reset(new TFE_Op{std::make_unique(ctx)}); + op.reset(new TFE_Op{ctx->context->CreateOperation()}); } status->status = op->operation->Reset(op_or_function_name, raw_device_name); if (!status->status.ok()) { diff --git a/tensorflow/python/eager/remote_cluster_test.py b/tensorflow/python/eager/remote_cluster_test.py index 78e7098d081..11310c0b5c4 100644 --- a/tensorflow/python/eager/remote_cluster_test.py +++ b/tensorflow/python/eager/remote_cluster_test.py @@ -495,28 +495,6 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): context.check_alive("/job:remote_device/replica:0/task:10") -class DynamicClusterWithoutLazyRemoteInputsCopyTest(DynamicClusterTest): - - @classmethod - def setUpClass(cls): - super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).setUpClass() - context._reset_context() - context.context().lazy_remote_inputs_copy = False - - @classmethod - def tearDownClass(cls): - super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).tearDownClass() - context._reset_context() - context.context().lazy_remote_inputs_copy = True - - # TODO(haoyuzhang): When lazyh remote inputs copy is disabled, we use the - # WorkerService RunGraph request to execute component functions in distributed - # function execution. We currently do not have access control in WorkerService - # to allow concurrent cluster update and function execution. - def testMultiThreadPendingNodesLockFree(self): - self.skipTest("Unsupported case") - - if __name__ == "__main__": ops.enable_eager_execution() test.main() diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index 2e64b8f73da..b32a773e894 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -31,11 +31,14 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.training import server_lib @@ -98,7 +101,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): self.assertAllEqual( remote_output(constant_op.constant([1]))[0].numpy(), 2) - # TODO(b/148235520): Re-enable this test. def testMultiDeviceFunctionAmbiguousDevice(self): @def_function.function @@ -168,6 +170,26 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): with ops.device('/job:worker/task:0'): self.assertAllEqual(func(), 1) + @test_util.eager_lazy_remote_copy_on_and_off + def testRemoteCall(self): + + @def_function.function( + input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) + def _remote_fn(x): + return constant_op.constant(1) + x + + remote_fn = _remote_fn.get_concrete_function() + + @def_function.function + def func(x): + return functional_ops.remote_call( + args=[x], + Tout=[dtypes.int32], + f=remote_fn, + target='/job:worker/task:0') + + self.assertAllEqual(func(constant_op.constant(1)), [2]) + class RemoteAsyncTest(test.TestCase): diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index b29214f3f07..2e2b831750a 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -113,6 +113,7 @@ tf_py_test( tags = [ "no_cuda_on_cpu_tap", "no_pip", + "no_rocm", "no_windows", ], deps = [ @@ -165,6 +166,7 @@ tf_py_test( tags = [ "no_cuda_on_cpu_tap", "no_pip", + "no_rocm", "no_windows", ], deps = [":feature_column_v2_test_main_lib"], diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 66e80b55852..ae30c15e844 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -472,21 +472,6 @@ class ImportGraphDefTest(test.TestCase): node { name: 'B' op: 'FloatInput' input: 'A:0' } """)) - def testShapeWhitelist(self): - # Barrier's shape is an output vector of 2, but the - # graph says it's a scalar. This is currently whitelisted. - with ops.Graph().as_default(): - _ = importer.import_graph_def( - self._MakeGraphDef(""" - node { name: 'A' op: 'Barrier' - attr { key: '_output_shapes' - value { list { shape { } } } } - attr { key: 'component_types' - value { list { type: DT_FLOAT } } } } - """), - return_elements=["A"], - name="import") - def testShapeWhitelistViolation(self): # L2 loss produces a scalar shape, but the graph # has the wrong shape, so raise an error. diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py index abf90547e50..c1b3a1775ec 100644 --- a/tensorflow/python/framework/indexed_slices.py +++ b/tensorflow/python/framework/indexed_slices.py @@ -54,13 +54,9 @@ tensor_util = LazyLoader( "tensor_util", globals(), "tensorflow.python.framework.tensor_util") -# pylint: disable=protected-access -_TensorLike = tensor_like._TensorLike -# pylint: enable=protected-access - @tf_export("IndexedSlices") -class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor): +class IndexedSlices(tensor_like.TensorLike, composite_tensor.CompositeTensor): """A sparse representation of a set of tensor slices at given indices. This class is a simple wrapper for a pair of `Tensor` objects: @@ -309,7 +305,7 @@ def internal_convert_to_tensor_or_indexed_slices(value, """ if isinstance(value, ops.EagerTensor) and not context.executing_eagerly(): return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref) - elif isinstance(value, _TensorLike): + elif isinstance(value, tensor_like.TensorLike): if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): raise ValueError( "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0df275f9093..e55a25787fe 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -94,7 +94,6 @@ _api_usage_gauge = monitoring.BoolGauge( # pylint: disable=protected-access -_TensorLike = tensor_like._TensorLike _DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE # pylint: enable=protected-access @@ -290,7 +289,7 @@ def disable_tensor_equality(): @tf_export("Tensor") -class Tensor(_TensorLike): +class Tensor(tensor_like.TensorLike): """A tensor is a multidimensional array of elements represented by a `tf.Tensor` object. All elements are of a single known data type. @@ -5983,9 +5982,12 @@ def _assert_same_graph(original_item, item): Raises: ValueError: if graphs do not match. """ - if original_item.graph is not item.graph: - raise ValueError("%s must be from the same graph as %s." % - (item, original_item)) + original_graph = getattr(original_item, "graph", None) + graph = getattr(item, "graph", None) + if original_graph and graph and original_graph is not graph: + raise ValueError( + "%s must be from the same graph as %s (graphs are %s and %s)." % + (item, original_item, graph, original_graph)) def _get_graph_from_inputs(op_input_list, graph=None): @@ -6039,7 +6041,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this # up. graph_element = None - if (isinstance(op_input, (Operation, _TensorLike)) and + if (isinstance(op_input, (Operation, tensor_like.TensorLike)) and ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck graph_element = op_input else: @@ -6048,7 +6050,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): if graph_element is not None: if not graph: original_graph_element = graph_element - graph = graph_element.graph + graph = getattr(graph_element, "graph", None) elif original_graph_element is not None: _assert_same_graph(original_graph_element, graph_element) elif graph_element.graph is not graph: diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 9af74ed569d..a175f80a6c3 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -38,14 +38,13 @@ from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access -_TensorLike = tensor_like._TensorLike _eval_using_default_session = ops._eval_using_default_session _override_helper = ops._override_helper # pylint: enable=protected-access @tf_export("sparse.SparseTensor", "SparseTensor") -class SparseTensor(_TensorLike, composite_tensor.CompositeTensor): +class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor): """Represents a sparse tensor. TensorFlow represents a sparse tensor as three separate dense tensors: diff --git a/tensorflow/python/framework/tensor_like.py b/tensorflow/python/framework/tensor_like.py index 5fc2d3dd01f..e8fe2f2fc05 100644 --- a/tensorflow/python/framework/tensor_like.py +++ b/tensorflow/python/framework/tensor_like.py @@ -19,7 +19,10 @@ from __future__ import division from __future__ import print_function -# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. -class _TensorLike(object): - """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" +class TensorLike(object): + """TF-specific types TF operations are expected to natively support. + + Do not check this with isinstance directly; prefer instead using + `tf.is_tensor` to check whether converting to a tensor is necessary. + """ pass diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 7504f7e27cd..30cac68bfc5 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_like from tensorflow.python.framework import tensor_shape +from tensorflow.python.types import core from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -977,25 +978,31 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name @tf_export("is_tensor") def is_tensor(x): # pylint: disable=invalid-name - """Checks whether `x` is a tensor or "tensor-like". + """Checks whether `x` is a TF-native type that can be passed to many TF ops. + + Use is_tensor to differentiate types that can ingested by TensorFlow ops + without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and + `tf.RaggedTensor`) from types that need to be converted into tensors before + they are ingested (e.g., numpy `ndarray` and Python scalars). + + For example, in the following code block: + + ```python + if not tf.is_tensor(t): + t = tf.convert_to_tensor(t) + return t.dtype + ``` + + we check to make sure that `t` is a tensor (and convert it if not) before + accessing its `shape` and `dtype`. - If `is_tensor(x)` returns `True`, it is safe to assume that `x` is a tensor or - can be converted to a tensor using `ops.convert_to_tensor(x)`. - - Usage example: - - >>> tf.is_tensor(tf.constant([[1,2,3],[4,5,6],[7,8,9]])) - True - >>> tf.is_tensor("Hello World") - False - Args: x: A python object to check. Returns: `True` if `x` is a tensor or "tensor-like", `False` if not. """ - return (isinstance(x, tensor_like._TensorLike) or # pylint: disable=protected-access + return (isinstance(x, (tensor_like.TensorLike, core.Tensor)) or ops.is_dense_tensor_like(x) or getattr(x, "is_tensor_like", False)) diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py index 8bb3a259228..494f6fc78fc 100644 --- a/tensorflow/python/grappler/auto_mixed_precision_test.py +++ b/tensorflow/python/grappler/auto_mixed_precision_test.py @@ -410,7 +410,11 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertEqual(num_to_fp16, 3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1 self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0 - self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) + # Bump up the tolerance for the ROCm platform + # The default tolerance (1e-3) results in a tiny fraction (<1%) of + # miscompares on ROCm platform, and hence the tolerance bump + tol = 2e-3 if test.is_built_with_rocm else 1e-3 + self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the # test_conv3d() below. @@ -489,7 +493,11 @@ class AutoMixedPrecisionTest(test.TestCase): self._assert_output_fp16(node_map, 'Conv2D_1') output_val_ref, output_val, cost_graph = self._run(output) - self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3) + # Bump up the tolerance for the ROCm platform + # The default tolerance (1e-3) results in a tiny fraction (<1%) of + # miscompares on ROCm platform, and hence the tolerance bump + tol = 2e-3 if test.is_built_with_rocm else 1e-3 + self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') @@ -602,7 +610,11 @@ class AutoMixedPrecisionTest(test.TestCase): self._assert_output_fp16(node_map, 'Relu') self._assert_output_fp16(node_map, 'MatMul_1') self._assert_output_fp16(node_map, 'Relu_1') - self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) + # Bump up the tolerance for the ROCm platform + # The default tolerance (1e-3) results in a tiny fraction (<1%) of + # miscompares on ROCm platform, and hence the tolerance bump + tol = 2e-3 if test.is_built_with_rocm else 1e-3 + self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) @test_util.run_v1_only('b/138749235') @test_util.disable_xla('This test does not pass with XLA') diff --git a/tensorflow/python/grappler/controller.py b/tensorflow/python/grappler/controller.py deleted file mode 100644 index 9f920026714..00000000000 --- a/tensorflow/python/grappler/controller.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Controller Class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict - - -class Controller(object): - """Controller class.""" - - def __init__(self, item, cluster): - """Controller class initializer. - - Args: - item: The metagraph to place wrapped in a cluster. - cluster: A cluster of devices on which to place the item. - """ - self.item = item - - self._node = {} - for node in item.metagraph.graph_def.node: - self._node[node.name] = node - - self._fanout = defaultdict(lambda: []) - for node in item.metagraph.graph_def.node: - for fanin in self._get_node_fanin(node): - self._fanout[fanin.name].append(node) - - important_op_names = item.IdentifyImportantOps(sort_topologically=True) - - # List of important ops (these are the ops to place) sorted in topological - # order. The order of this collection is deterministic. - self.important_ops = [] - for name in important_op_names: - self.important_ops.append(self._node[name]) - - self.node_properties = item.GetOpProperties() - - self.cluster = cluster - self.devices = cluster.ListDevices() - - self.colocation_constraints = item.GetColocationGroups() - - self.placement_constraints = cluster.GetSupportedDevices(item) - for node_name, dev in self.placement_constraints.items(): - if len(dev) == 1: - # Place the node on the supported device - node = self._node[node_name] - node.device = dev[0] - fanout = self.get_node_fanout(node) - # Update the fanout of the fanin to bypass the node - for fanin in self._get_node_fanin(node): - fanout_of_fanin = self.get_node_fanout(fanin) - fanout_of_fanin += fanout - fanout_of_fanin.remove(node) - # Remove node from the list of important ops since we don't need to - # place the node. - if node in self.important_ops: - self.important_ops.remove(node) - important_op_names.remove(node.name) - - # List of important op names, in non deterministic order. - self.important_op_names = frozenset(important_op_names) - - @property - def input_graph_def(self): - return self.item.metagraph.graph_def - - @property - def num_devices(self): - return len(self.devices) - - def get_node_by_name(self, node_name): - return self._node[node_name] - - def get_node_fanout(self, node): - return self._fanout[node.name] - - def get_placements(self, *args, **kwargs): - """Returns: Two TF ops. - - Args: - *args: "". - **kwargs: "". - - Returns: - y_preds: tensor of size [batch_size, num_ops] - log_probs: python dict of at least two fields: "sample", "target" each - containing a tensor of size [batch_size], corresponding to the log_probs. - """ - raise NotImplementedError - - def eval_placement(self, sess, *args, **kwargs): - """At this time, this method evaluates ONLY ONE placement. - - Args: - sess: a tf.compat.v1.Session() object used to retrieve cached assignment - info. - *args: "". - **kwargs: "". - - Returns: - run_time: scalar - """ - raise NotImplementedError - - def export_placement(self, metagraph): - """Annotate the placement onto the specified metagraph. - - Args: - metagraph: the metagraph to annotate with the placement. - """ - for node in metagraph.graph_def.node: - if node.name in self.important_op_names: - node.device = self.get_node_by_name(node.name).device - - # Get the nodes in the immediate fanin of node. - # Beware: this doesn't take into account the nodes that may be skipped - # since placement constraints force their placement. - def _get_node_fanin(self, node): - input_ops = [] - for fanin_name in node.input: - if fanin_name[0] == "^": - fanin_name = fanin_name[1:] - fanin_name = fanin_name.split(":")[0] - input_ops.append(self.get_node_by_name(fanin_name)) - return input_ops diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py deleted file mode 100644 index 9c05ad81790..00000000000 --- a/tensorflow/python/grappler/graph_placer.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Graph Placer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.grappler import cluster as gcluster -from tensorflow.python.grappler import hierarchical_controller -from tensorflow.python.grappler import item as gitem -from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.training import training - - -def PlaceGraph(metagraph, - cluster=None, - allotted_time=3600, - hparams=None, - verbose=False): - """Place the provided metagraph. - - Args: - metagraph: the metagraph to place. - cluster: an optional set of hardware resource to optimize the placement for. - If none is specified, we'll optimize the placement for the hardware - available on the local machine. - allotted_time: the maximum amount to time in seconds to spend optimizing - the placement. - hparams: hyperparameters used to fine tune the placer. - verbose: prints debug information if True. - - Returns: - The placed metagraph. - """ - if cluster is None: - cluster = gcluster.Cluster() - - # Optimize the metagraph to speedup the placement - config = config_pb2.ConfigProto() - optimized_graph = tf_optimizer.OptimizeGraph( - config, metagraph, verbose=verbose, cluster=cluster) - optimized_metagraph = meta_graph_pb2.MetaGraphDef() - optimized_metagraph.CopyFrom(metagraph) - optimized_metagraph.graph_def.CopyFrom(optimized_graph) - - item = gitem.Item(optimized_metagraph) - - # Measure the runtime achievable with the original placement. - try: - _, original_run_time, _ = cluster.MeasureCosts(item) - if verbose: - print("Runtime for original placement: " + str(original_run_time)) - except errors.OpError as e: - if verbose: - print("Original placement isn't feasible: " + str(e)) - original_run_time = hparams.failing_signal - - if hparams is None: - hparams = hierarchical_controller.hierarchical_controller_hparams() - # We run with a single child - hparams.num_children = 1 - - with tf_ops.Graph().as_default(): - # Place all the nodes of the controller on the CPU. We don't want them to - # fight for accelerator memory with the model to optimize. - with tf_ops.device("/device:CPU:0"): - model = hierarchical_controller.HierarchicalController( - hparams, item, cluster) - ops = model.build_controller() - session_creator = training.ChiefSessionCreator() - with training.MonitoredSession(session_creator=session_creator) as sess: - start_time = time.time() - current_time = start_time - while current_time - start_time < allotted_time: - grouping_actions = model.generate_grouping(sess) - input_to_seq2seq = model.create_group_embeddings( - grouping_actions, verbose=verbose) - model.generate_placement(input_to_seq2seq, sess) - try: - run_time = model.eval_placement( - sess, - verbose=verbose) - except errors.OpError as e: - if verbose: - print("Failed to run graph:" + str(e)) - run_time = hparams.failing_signal - updated = model.update_reward(sess, run_time, verbose=verbose) - if updated and run_time < original_run_time: - if verbose: - print("Found better placement, with runtime " + str(run_time)) - model.export_placement(metagraph) - - model.process_reward(sess) - - current_time = time.time() - - return metagraph diff --git a/tensorflow/python/grappler/graph_placer_test.py b/tensorflow/python/grappler/graph_placer_test.py deleted file mode 100644 index 9eabe3cd543..00000000000 --- a/tensorflow/python/grappler/graph_placer_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests the graph placer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.core.protobuf import device_properties_pb2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import meta_graph -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.grappler import cluster -from tensorflow.python.grappler import graph_placer -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test - - -class GraphPlacerTest(test.TestCase): - - @staticmethod - def _buildMnist(batch_size=128, - input_size=256, - num_classes=1024, - num_layers=10, - hidden_size=256, - name='mnist'): - g = tf_ops.get_default_graph() - with g.as_default(): - ops = {} - x = random_ops.random_uniform( - [batch_size, input_size], -0.1, 0.1, dtype=dtypes.float32) - for layer_id in range(num_layers): - with variable_scope.variable_scope('layer_{}'.format(layer_id)): - a = input_size if layer_id == 0 else hidden_size - b = hidden_size if layer_id < num_layers - 1 else num_classes - w = variable_scope.get_variable('w', [a, b]) - x = math_ops.matmul(x, w) - x = nn_ops.relu(x) - ops['y_preds'] = math_ops.argmax(x, axis=1) - - train_op = g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) - train_op.append(ops['y_preds']) - return g - - @staticmethod - def _buildCluster(num_cpus=1, num_gpus=1): - devices = [] - if num_gpus > 0: - device_properties = device_properties_pb2.DeviceProperties( - type='GPU', - vendor='NVidia', - model='GeForce GTX TITAN X', - frequency=1076, - num_cores=24, - environment={'architecture': '5.2', - 'cuda': '8000', - 'cudnn': '6021'}, - num_registers=65536, - l1_cache_size=24576, - l2_cache_size=3145728, - shared_memory_size_per_multiprocessor=98304, - memory_size=12783648768, - bandwidth=336480000) - for i in range(num_gpus): - devices.append( - device_properties_pb2.NamedDevice( - properties=device_properties, name='/GPU:' + str(i))) - - assert num_cpus > 0 - device_properties = device_properties_pb2.DeviceProperties( - type='CPU', - frequency=2000, - num_cores=4, - l1_cache_size=32768, - l2_cache_size=262144, - l3_cache_size=12582912) - for i in range(num_cpus): - devices.append( - device_properties_pb2.NamedDevice( - properties=device_properties, name='/CPU:' + str(i))) - - return cluster.Cluster(devices=devices) - - def testBasic(self): - """Place a trivial graph.""" - a = constant_op.constant(10, name='a') - b = constant_op.constant(20, name='b') - c = math_ops.add_n([a, b], name='c') - d = math_ops.add_n([b, c], name='d') - train_op = tf_ops.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) - train_op.append(d) - mg = meta_graph.create_meta_graph_def(graph=tf_ops.get_default_graph()) - - gcluster = cluster.Cluster() - placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) - - self.assertEqual(4, len(placed_mg.graph_def.node)) - self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], - [node.name for node in mg.graph_def.node]) - - available_devices = [device.name for device in gcluster.ListDevices()] - for node in placed_mg.graph_def.node: - # The constant nodes are optimized away before the placer is run, and - # therefore won't be placed. - self.assertTrue(not node.device or node.device in available_devices) - - def testMNIST(self): - graph = GraphPlacerTest._buildMnist() - mg = meta_graph.create_meta_graph_def(graph=graph) - gcluster = GraphPlacerTest._buildCluster(num_gpus=1) - # Spend 15 seconds trying to optimize the placement of the model. This - # should give us enough time to exercise the code, but not enough to find - # a good placement, so we'll just check for legality. - placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) - self.assertEqual(len(placed_mg.graph_def.node), len(mg.graph_def.node)) - self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], - [node.name for node in mg.graph_def.node]) - available_devices = [device.name for device in gcluster.ListDevices()] - for node in placed_mg.graph_def.node: - self.assertTrue(not node.device or node.device in available_devices) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py deleted file mode 100644 index a6fc5051a6c..00000000000 --- a/tensorflow/python/grappler/hierarchical_controller.py +++ /dev/null @@ -1,1118 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""HierarchicalController Class. - -The HierarchicalController encompasses the entire lifecycle of training the -device placement policy, including generating op embeddings, getting groups for -each op, placing those groups and running the predicted placements. - -Different assignment models can inherit from this class. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -import numpy as np -import six -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.grappler.controller import Controller -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary -from tensorflow.python.training import adam -from tensorflow.python.training import gradient_descent -from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import training_util - - -class PlacerParams(object): - """Class to hold a set of placement parameters as name-value pairs. - - A typical usage is as follows: - - ```python - # Create a PlacerParams object specifying names and values of the model - # parameters: - params = PlacerParams(hidden_size=128, decay_steps=50) - - # The parameters are available as attributes of the PlacerParams object: - hparams.hidden_size ==> 128 - hparams.decay_steps ==> 50 - ``` - - """ - - def __init__(self, **kwargs): - """Create an instance of `PlacerParams` from keyword arguments. - - The keyword arguments specify name-values pairs for the parameters. - The parameter types are inferred from the type of the values passed. - - The parameter names are added as attributes of `PlacerParams` object, - and they can be accessed directly with the dot notation `params._name_`. - - Example: - - ```python - # Define 1 parameter: 'hidden_size' - params = PlacerParams(hidden_size=128) - params.hidden_size ==> 128 - ``` - - Args: - **kwargs: Key-value pairs where the key is the parameter name and - the value is the value for the parameter. - """ - for name, value in six.iteritems(kwargs): - self.add_param(name, value) - - def add_param(self, name, value): - """Adds {name, value} pair to hyperparameters. - - Args: - name: Name of the hyperparameter. - value: Value of the hyperparameter. Can be one of the following types: - int, float, string, int list, float list, or string list. - - Raises: - ValueError: if one of the arguments is invalid. - """ - # Keys in kwargs are unique, but 'name' could be the name of a pre-existing - # attribute of this object. In that case we refuse to use it as a - # parameter name. - if getattr(self, name, None) is not None: - raise ValueError("Parameter name is reserved: %s" % name) - setattr(self, name, value) - - -def hierarchical_controller_hparams(): - """Hyperparameters for hierarchical planner.""" - return PlacerParams( - hidden_size=512, - forget_bias_init=1.0, - temperature=1.0, - logits_std_noise=0.5, - stop_noise_step=750, - decay_steps=50, - max_num_outputs=5, - max_output_size=5, - tanh_constant=1.0, - adj_embed_dim=20, - grouping_hidden_size=64, - num_groups=None, - bi_lstm=True, - failing_signal=100, - stop_sampling=500, - start_with_failing_signal=True, - always_update_baseline=False, - bl_dec=0.9, - grad_bound=1.0, - lr=0.1, - lr_dec=0.95, - start_decay_step=400, - optimizer_type="adam", - stop_updating_after_steps=1000, - name="hierarchical_controller", - keep_prob=1.0, - reward_function="sqrt", - seed=1234, - # distributed training params - num_children=1) - - -class HierarchicalController(Controller): - """HierarchicalController class.""" - - def __init__(self, hparams, item, cluster, controller_id=0): - """HierarchicalController class initializer. - - Args: - hparams: All hyper-parameters. - item: The metagraph to place. - cluster: The cluster of hardware devices to optimize for. - controller_id: the id of the controller in a multi-controller setup. - """ - super(HierarchicalController, self).__init__(item, cluster) - self.ctrl_id = controller_id - self.hparams = hparams - - if self.hparams.num_groups is None: - self.num_groups = min(256, 20 * self.num_devices) - else: - self.num_groups = self.hparams.num_groups - - # creates self.op_embeddings and self.type_dict - self.create_op_embeddings(verbose=False) - # TODO(azalia) clean up embedding/group_embedding_size names - self.group_emb_size = ( - 2 * self.num_groups + len(self.type_dict) + - self.hparams.max_num_outputs * self.hparams.max_output_size) - self.embedding_size = self.group_emb_size - self.initializer = init_ops.glorot_uniform_initializer( - seed=self.hparams.seed) - - with variable_scope.variable_scope( - self.hparams.name, - initializer=self.initializer, - reuse=variable_scope.AUTO_REUSE): - # define parameters of feedforward - variable_scope.get_variable("w_grouping_ff", [ - 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + - self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size - ]) - variable_scope.get_variable( - "w_grouping_softmax", - [self.hparams.grouping_hidden_size, self.num_groups]) - if self.hparams.bi_lstm: - variable_scope.get_variable("encoder_lstm_forward", [ - self.embedding_size + self.hparams.hidden_size // 2, - 2 * self.hparams.hidden_size - ]) - variable_scope.get_variable("encoder_lstm_backward", [ - self.embedding_size + self.hparams.hidden_size // 2, - 2 * self.hparams.hidden_size - ]) - variable_scope.get_variable( - "device_embeddings", [self.num_devices, self.hparams.hidden_size]) - variable_scope.get_variable( - "decoder_lstm", - [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) - variable_scope.get_variable( - "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) - variable_scope.get_variable("device_go_embedding", - [1, self.hparams.hidden_size]) - variable_scope.get_variable( - "encoder_forget_bias", - shape=1, - dtype=dtypes.float32, - initializer=init_ops.constant_initializer( - self.hparams.forget_bias_init)) - variable_scope.get_variable( - "decoder_forget_bias", - shape=1, - dtype=dtypes.float32, - initializer=init_ops.constant_initializer( - self.hparams.forget_bias_init)) - variable_scope.get_variable( - "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) - variable_scope.get_variable( - "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) - variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) - - else: - variable_scope.get_variable("encoder_lstm", [ - self.embedding_size + self.hparams.hidden_size, - 4 * self.hparams.hidden_size - ]) - variable_scope.get_variable( - "device_embeddings", [self.num_devices, self.hparams.hidden_size]) - variable_scope.get_variable( - "decoder_lstm", - [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) - variable_scope.get_variable( - "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) - variable_scope.get_variable("device_go_embedding", - [1, self.hparams.hidden_size]) - variable_scope.get_variable( - "encoder_forget_bias", - shape=1, - dtype=dtypes.float32, - initializer=init_ops.constant_initializer( - self.hparams.forget_bias_init)) - variable_scope.get_variable( - "decoder_forget_bias", - shape=1, - dtype=dtypes.float32, - initializer=init_ops.constant_initializer( - self.hparams.forget_bias_init)) - variable_scope.get_variable( - "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) - variable_scope.get_variable( - "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) - variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) - seq2seq_input_layer = array_ops.placeholder_with_default( - array_ops.zeros([self.hparams.num_children, - self.num_groups, - self.group_emb_size], - dtypes.float32), - shape=(self.hparams.num_children, self.num_groups, self.group_emb_size)) - self.seq2seq_input_layer = seq2seq_input_layer - - def compute_reward(self, run_time): - if self.hparams.reward_function == "id": - reward = run_time - elif self.hparams.reward_function == "sqrt": - reward = math.sqrt(run_time) - elif self.hparams.reward_function == "log": - reward = math.log1p(run_time) - else: - raise NotImplementedError( - "Unrecognized reward function '%s', consider your " - "--reward_function flag value." % self.hparams.reward_function) - return reward - - def build_controller(self): - """RL optimization interface. - - Returns: - ops: A dictionary holding handles of the model used for training. - """ - - self._global_step = training_util.get_or_create_global_step() - ops = {} - ops["loss"] = 0 - - failing_signal = self.compute_reward(self.hparams.failing_signal) - - ctr = {} - - with tf_ops.name_scope("controller_{}".format(self.ctrl_id)): - with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): - ctr["reward"] = {"value": [], "ph": [], "update": []} - ctr["ready"] = {"value": [], "ph": [], "update": []} - ctr["best_reward"] = {"value": [], "update": []} - for i in range(self.hparams.num_children): - reward_value = variable_scope.get_local_variable( - "reward_{}".format(i), - initializer=0.0, - dtype=dtypes.float32, - trainable=False) - reward_ph = array_ops.placeholder( - dtypes.float32, shape=(), name="reward_ph_{}".format(i)) - reward_update = state_ops.assign( - reward_value, reward_ph, use_locking=True) - ctr["reward"]["value"].append(reward_value) - ctr["reward"]["ph"].append(reward_ph) - ctr["reward"]["update"].append(reward_update) - best_reward = variable_scope.get_local_variable( - "best_reward_{}".format(i), - initializer=failing_signal, - dtype=dtypes.float32, - trainable=False) - ctr["best_reward"]["value"].append(best_reward) - ctr["best_reward"]["update"].append( - state_ops.assign(best_reward, - math_ops.minimum(best_reward, reward_update))) - - ready_value = variable_scope.get_local_variable( - "ready_{}".format(i), - initializer=True, - dtype=dtypes.bool, - trainable=False) - ready_ph = array_ops.placeholder( - dtypes.bool, shape=(), name="ready_ph_{}".format(i)) - ready_update = state_ops.assign( - ready_value, ready_ph, use_locking=True) - ctr["ready"]["value"].append(ready_value) - ctr["ready"]["ph"].append(ready_ph) - ctr["ready"]["update"].append(ready_update) - - ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings() - summary.histogram( - "grouping_actions", - array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0], - [1, array_ops.shape(self.op_embeddings)[0]])) - - with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): - ctr["baseline"] = variable_scope.get_local_variable( - "baseline", - initializer=failing_signal - if self.hparams.start_with_failing_signal else 0.0, - dtype=dtypes.float32, - trainable=False) - - new_baseline = self.hparams.bl_dec * ctr["baseline"] + ( - 1 - self.hparams.bl_dec) * math_ops.reduce_mean( - ctr["reward"]["value"]) - if not self.hparams.always_update_baseline: - baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal) - selected_reward = array_ops.boolean_mask(ctr["reward"]["value"], - baseline_mask) - selected_baseline = control_flow_ops.cond( - math_ops.reduce_any(baseline_mask), - lambda: math_ops.reduce_mean(selected_reward), - lambda: constant_op.constant(0, dtype=dtypes.float32)) - ctr["pos_reward"] = selected_baseline - pos_ = math_ops.less( - constant_op.constant(0, dtype=dtypes.float32), selected_baseline) - selected_baseline = self.hparams.bl_dec * ctr["baseline"] + ( - 1 - self.hparams.bl_dec) * selected_baseline - selected_baseline = control_flow_ops.cond( - pos_, lambda: selected_baseline, lambda: ctr["baseline"]) - new_baseline = control_flow_ops.cond( - math_ops.less(self.global_step, - self.hparams.stop_updating_after_steps), - lambda: new_baseline, lambda: selected_baseline) - ctr["baseline_update"] = state_ops.assign( - ctr["baseline"], new_baseline, use_locking=True) - - ctr["y_preds"], ctr["log_probs"] = self.get_placements() - summary.histogram("actions", ctr["y_preds"]["sample"]) - mask = math_ops.less(ctr["reward"]["value"], failing_signal) - ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"] - ctr["loss"] *= ( - ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"]) - - selected_loss = array_ops.boolean_mask(ctr["loss"], mask) - selected_loss = control_flow_ops.cond( - math_ops.reduce_any(mask), - lambda: math_ops.reduce_mean(-selected_loss), - lambda: constant_op.constant(0, dtype=dtypes.float32)) - - ctr["loss"] = control_flow_ops.cond( - math_ops.less(self.global_step, - self.hparams.stop_updating_after_steps), - lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss) - - ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"]) - summary.scalar("loss", ctr["loss"]) - summary.scalar("avg_reward", ctr["reward_s"]) - summary.scalar("best_reward_so_far", best_reward) - summary.scalar( - "advantage", - math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"])) - - with variable_scope.variable_scope( - "optimizer", reuse=variable_scope.AUTO_REUSE): - (ctr["train_op"], ctr["lr"], ctr["grad_norm"], - ctr["grad_norms"]) = self._get_train_ops( - ctr["loss"], - tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES), - self.global_step, - grad_bound=self.hparams.grad_bound, - lr_init=self.hparams.lr, - lr_dec=self.hparams.lr_dec, - start_decay_step=self.hparams.start_decay_step, - decay_steps=self.hparams.decay_steps, - optimizer_type=self.hparams.optimizer_type) - - summary.scalar("gradnorm", ctr["grad_norm"]) - summary.scalar("lr", ctr["lr"]) - ctr["summary"] = summary.merge_all() - ops["controller"] = ctr - - self.ops = ops - return ops - - @property - def global_step(self): - return self._global_step - - def create_op_embeddings(self, verbose=False): - if verbose: - print("process input graph for op embeddings") - self.num_ops = len(self.important_ops) - # topological sort of important nodes - topo_order = [op.name for op in self.important_ops] - - # create index to name for topologicaly sorted important nodes - name_to_topo_order_index = {} - for idx, x in enumerate(topo_order): - name_to_topo_order_index[x] = idx - self.name_to_topo_order_index = name_to_topo_order_index - - # create adj matrix - adj_dict = {} - for idx, op in enumerate(self.important_ops): - for output_op in self.get_node_fanout(op): - output_op_name = output_op.name - if output_op_name in self.important_op_names: - if name_to_topo_order_index[op.name] not in adj_dict: - adj_dict[name_to_topo_order_index[op.name]] = [] - adj_dict[name_to_topo_order_index[op.name]].extend( - [name_to_topo_order_index[output_op_name], 1]) - if output_op_name not in adj_dict: - adj_dict[name_to_topo_order_index[output_op_name]] = [] - adj_dict[name_to_topo_order_index[output_op_name]].extend( - [name_to_topo_order_index[op.name], -1]) - - # get op_type op_output_shape, and adj info - output_embed_dim = (self.hparams.max_num_outputs * - self.hparams.max_output_size) - - # TODO(bsteiner): don't filter based on used ops so that we can generalize - # to models that use other types of ops. - used_ops = set() - for node in self.important_ops: - op_type = str(node.op) - used_ops.add(op_type) - - self.type_dict = {} - for op_type in self.cluster.ListAvailableOps(): - if op_type in used_ops: - self.type_dict[op_type] = len(self.type_dict) - - op_types = np.zeros([self.num_ops], dtype=np.int32) - op_output_shapes = np.full( - [self.num_ops, output_embed_dim], -1.0, dtype=np.float32) - for idx, node in enumerate(self.important_ops): - op_types[idx] = self.type_dict[node.op] - # output shape - op_name = node.name - for i, output_prop in enumerate(self.node_properties[op_name]): - if output_prop.shape.__str__() == "": - continue - shape = output_prop.shape - for j, dim in enumerate(shape.dim): - if dim.size >= 0: - if i * self.hparams.max_output_size + j >= output_embed_dim: - break - op_output_shapes[idx, - i * self.hparams.max_output_size + j] = dim.size - # adj for padding - op_adj = np.full( - [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32) - for idx in adj_dict: - neighbors = adj_dict[int(idx)] - min_dim = min(self.hparams.adj_embed_dim, len(neighbors)) - padding_size = self.hparams.adj_embed_dim - min_dim - neighbors = neighbors[:min_dim] + [0] * padding_size - op_adj[int(idx)] = neighbors - - # op_embedding starts here - op_embeddings = np.zeros( - [ - self.num_ops, - 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + - self.hparams.adj_embed_dim - ], - dtype=np.float32) - for idx, op_name in enumerate(topo_order): - op_embeddings[idx] = np.concatenate( - (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)])) - self.op_embeddings = constant_op.constant( - op_embeddings, dtype=dtypes.float32) - if verbose: - print("num_ops = {}".format(self.num_ops)) - print("num_types = {}".format(len(self.type_dict))) - - def get_groupings(self, *args, **kwargs): - num_children = self.hparams.num_children - with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): - grouping_actions_cache = variable_scope.get_local_variable( - "grouping_actions_cache", - initializer=init_ops.zeros_initializer, - dtype=dtypes.int32, - shape=[num_children, self.num_ops], - trainable=False) - input_layer = self.op_embeddings - input_layer = array_ops.expand_dims(input_layer, 0) - feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1]) - grouping_actions, grouping_log_probs = {}, {} - grouping_actions["sample"], grouping_log_probs[ - "sample"] = self.make_grouping_predictions(feed_ff_input_layer) - - grouping_actions["sample"] = state_ops.assign(grouping_actions_cache, - grouping_actions["sample"]) - self.grouping_actions_cache = grouping_actions_cache - - return grouping_actions, grouping_log_probs - - def make_grouping_predictions(self, input_layer, reuse=None): - """model that predicts grouping (grouping_actions). - - Args: - input_layer: group_input_layer - reuse: reuse - - Returns: - grouping_actions: actions - grouping_log_probs: log probabilities corresponding to actions - """ - with variable_scope.variable_scope(self.hparams.name, reuse=True): - # input_layer: tensor of size [1, num_ops, hidden_size] - w_grouping_ff = variable_scope.get_variable("w_grouping_ff") - w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax") - - batch_size = array_ops.shape(input_layer)[0] - embedding_dim = array_ops.shape(input_layer)[2] - - reshaped = array_ops.reshape(input_layer, - [batch_size * self.num_ops, embedding_dim]) - ff_output = math_ops.matmul(reshaped, w_grouping_ff) - logits = math_ops.matmul(ff_output, w_grouping_softmax) - if self.hparams.logits_std_noise > 0: - num_in_logits = math_ops.cast( - array_ops.size(logits), dtype=dtypes.float32) - avg_norm = math_ops.divide( - linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) - logits_noise = random_ops.random_normal( - array_ops.shape(logits), - stddev=self.hparams.logits_std_noise * avg_norm) - logits = control_flow_ops.cond( - self.global_step > self.hparams.stop_noise_step, lambda: logits, - lambda: logits + logits_noise) - logits = array_ops.reshape(logits, - [batch_size * self.num_ops, self.num_groups]) - actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed) - actions = math_ops.cast(actions, dtypes.int32) - actions = array_ops.reshape(actions, [batch_size, self.num_ops]) - action_label = array_ops.reshape(actions, [-1]) - log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=action_label) - log_probs = array_ops.reshape(log_probs, [batch_size, -1]) - log_probs = math_ops.reduce_sum(log_probs, 1) - grouping_actions = actions - grouping_log_probs = log_probs - return grouping_actions, grouping_log_probs - - def create_group_embeddings(self, grouping_actions, verbose=False): - """Approximating the blocks of a TF graph from a graph_def. - - Args: - grouping_actions: grouping predictions. - verbose: print stuffs. - - Returns: - groups: list of groups. - """ - groups = [ - self._create_group_embeddings(grouping_actions, i, verbose) for - i in range(self.hparams.num_children) - ] - return np.stack(groups, axis=0) - - def _create_group_embeddings(self, grouping_actions, child_id, verbose=False): - """Approximating the blocks of a TF graph from a graph_def for each child. - - Args: - grouping_actions: grouping predictions. - child_id: child_id for the group. - verbose: print stuffs. - - Returns: - groups: group embedding for the child_id. - """ - if verbose: - print("Processing input_graph") - - # TODO(azalia): Build inter-adjacencies dag matrix. - # record dag_matrix - dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32) - for op in self.important_ops: - topo_op_index = self.name_to_topo_order_index[op.name] - group_index = grouping_actions[child_id][topo_op_index] - for output_op in self.get_node_fanout(op): - if output_op.name not in self.important_op_names: - continue - output_group_index = ( - grouping_actions[child_id][self.name_to_topo_order_index[ - output_op.name]]) - dag_matrix[group_index, output_group_index] += 1.0 - num_connections = np.sum(dag_matrix) - num_intra_group_connections = dag_matrix.trace() - num_inter_group_connections = num_connections - num_intra_group_connections - if verbose: - print("grouping evaluation metric") - print(("num_connections={} num_intra_group_connections={} " - "num_inter_group_connections={}").format( - num_connections, num_intra_group_connections, - num_inter_group_connections)) - self.dag_matrix = dag_matrix - - # output_shape - op_output_shapes = np.zeros( - [ - len(self.important_ops), - self.hparams.max_num_outputs * self.hparams.max_output_size - ], - dtype=np.float32) - - for idx, op in enumerate(self.important_ops): - for i, output_properties in enumerate(self.node_properties[op.name]): - if output_properties.shape.__str__() == "": - continue - if i > self.hparams.max_num_outputs: - break - shape = output_properties.shape - for j, dim in enumerate(shape.dim): - if dim.size > 0: - k = i * self.hparams.max_output_size + j - if k >= self.hparams.max_num_outputs * self.hparams.max_output_size: - break - op_output_shapes[idx, k] = dim.size - - # group_embedding - group_embedding = np.zeros( - [ - self.num_groups, len(self.type_dict) + - self.hparams.max_num_outputs * self.hparams.max_output_size - ], - dtype=np.float32) - for op_index, op in enumerate(self.important_ops): - group_index = grouping_actions[child_id][ - self.name_to_topo_order_index[op.name]] - type_name = str(op.op) - type_index = self.type_dict[type_name] - group_embedding[group_index, type_index] += 1 - group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams. - max_output_size] += ( - op_output_shapes[op_index]) - grouping_adjacencies = np.concatenate( - [dag_matrix, np.transpose(dag_matrix)], axis=1) - group_embedding = np.concatenate( - [grouping_adjacencies, group_embedding], axis=1) - group_normalizer = np.amax(group_embedding, axis=1, keepdims=True) - group_embedding /= (group_normalizer + 1.0) - if verbose: - print("Finished Processing Input Graph") - return group_embedding - - def get_placements(self, *args, **kwargs): - num_children = self.hparams.num_children - with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): - actions_cache = variable_scope.get_local_variable( - "actions_cache", - initializer=init_ops.zeros_initializer, - dtype=dtypes.int32, - shape=[num_children, self.num_groups], - trainable=False) - - x = self.seq2seq_input_layer - last_c, last_h, attn_mem = self.encode(x) - actions, log_probs = {}, {} - actions["sample"], log_probs["sample"] = ( - self.decode( - x, last_c, last_h, attn_mem, mode="sample")) - actions["target"], log_probs["target"] = ( - self.decode( - x, - last_c, - last_h, - attn_mem, - mode="target", - y=actions_cache)) - actions["greedy"], log_probs["greedy"] = ( - self.decode( - x, last_c, last_h, attn_mem, mode="greedy")) - actions["sample"] = control_flow_ops.cond( - self.global_step < self.hparams.stop_sampling, - lambda: state_ops.assign(actions_cache, actions["sample"]), - lambda: state_ops.assign(actions_cache, actions["target"])) - self.actions_cache = actions_cache - - return actions, log_probs - - def encode(self, x): - """Encoder using LSTM. - - Args: - x: tensor of size [num_children, num_groups, embedding_size] - - Returns: - last_c, last_h: tensors of size [num_children, hidden_size], the final - LSTM states - attn_mem: tensor of size [num_children, num_groups, hidden_size], the - attention - memory, i.e. concatenation of all hidden states, linearly transformed by - an attention matrix attn_w_1 - """ - if self.hparams.bi_lstm: - with variable_scope.variable_scope(self.hparams.name, reuse=True): - w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward") - w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward") - forget_bias = variable_scope.get_variable("encoder_forget_bias") - attn_w_1 = variable_scope.get_variable("attn_w_1") - else: - with variable_scope.variable_scope(self.hparams.name, reuse=True): - w_lstm = variable_scope.get_variable("encoder_lstm") - forget_bias = variable_scope.get_variable("encoder_forget_bias") - attn_w_1 = variable_scope.get_variable("attn_w_1") - - embedding_size = array_ops.shape(x)[2] - - signals = array_ops.split(x, self.num_groups, axis=1) - for i in range(len(signals)): - signals[i] = array_ops.reshape( - signals[i], [self.hparams.num_children, embedding_size]) - - if self.hparams.bi_lstm: - - def body(i, prev_c_forward, prev_h_forward, prev_c_backward, - prev_h_backward): - """while loop for LSTM.""" - signal_forward = signals[i] - next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward, - prev_h_forward, w_lstm_forward, - forget_bias) - - signal_backward = signals[self.num_groups - 1 - i] - next_c_backward, next_h_backward = lstm( - signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward, - forget_bias) - - next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1) - all_h.append(next_h) - - return (next_c_forward, next_h_forward, next_c_backward, - next_h_backward) - - c_forward = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size / 2], - dtype=dtypes.float32) - h_forward = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size / 2], - dtype=dtypes.float32) - - c_backward = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size / 2], - dtype=dtypes.float32) - h_backward = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size / 2], - dtype=dtypes.float32) - all_h = [] - - for i in range(0, self.num_groups): - c_forward, h_forward, c_backward, h_backward = body( - i, c_forward, h_forward, c_backward, h_backward) - - last_c = array_ops.concat([c_forward, c_backward], axis=1) - last_h = array_ops.concat([h_forward, h_backward], axis=1) - attn_mem = array_ops.stack(all_h) - - else: - - def body(i, prev_c, prev_h): - signal = signals[i] - next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) - all_h.append(next_h) - return next_c, next_h - - c = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size], - dtype=dtypes.float32) - h = array_ops.zeros( - [self.hparams.num_children, self.hparams.hidden_size], - dtype=dtypes.float32) - all_h = [] - - for i in range(0, self.num_groups): - c, h = body(i, c, h) - - last_c = c - last_h = h - attn_mem = array_ops.stack(all_h) - - attn_mem = array_ops.transpose(attn_mem, [1, 0, 2]) - attn_mem = array_ops.reshape( - attn_mem, - [self.hparams.num_children * self.num_groups, self.hparams.hidden_size]) - attn_mem = math_ops.matmul(attn_mem, attn_w_1) - attn_mem = array_ops.reshape( - attn_mem, - [self.hparams.num_children, self.num_groups, self.hparams.hidden_size]) - - return last_c, last_h, attn_mem - - def decode(self, - x, - last_c, - last_h, - attn_mem, - mode="target", - y=None): - """Decoder using LSTM. - - Args: - x: tensor of size [num_children, num_groups, embedding_size]. - last_c: tensor of size [num_children, hidden_size], the final LSTM states - computed by self.encoder. - last_h: same as last_c. - attn_mem: tensor of size [num_children, num_groups, hidden_size]. - mode: "target" or "sample". - y: tensor of size [num_children, num_groups], the device placements. - - Returns: - actions: tensor of size [num_children, num_groups], the placements of - devices - """ - with variable_scope.variable_scope(self.hparams.name, reuse=True): - w_lstm = variable_scope.get_variable("decoder_lstm") - forget_bias = variable_scope.get_variable("decoder_forget_bias") - device_embeddings = variable_scope.get_variable("device_embeddings") - device_softmax = variable_scope.get_variable("device_softmax") - device_go_embedding = variable_scope.get_variable("device_go_embedding") - attn_w_2 = variable_scope.get_variable("attn_w_2") - attn_v = variable_scope.get_variable("attn_v") - - actions = tensor_array_ops.TensorArray( - dtypes.int32, - size=self.num_groups, - infer_shape=False, - clear_after_read=False) - - # pylint: disable=unused-argument - def condition(i, *args): - return math_ops.less(i, self.num_groups) - - # pylint: disable=missing-docstring - def body(i, prev_c, prev_h, actions, log_probs): - # pylint: disable=g-long-lambda - signal = control_flow_ops.cond( - math_ops.equal(i, 0), - lambda: array_ops.tile(device_go_embedding, - [self.hparams.num_children, 1]), - lambda: embedding_ops.embedding_lookup(device_embeddings, - actions.read(i - 1)) - ) - if self.hparams.keep_prob is not None: - signal = nn_ops.dropout(signal, rate=(1 - self.hparams.keep_prob)) - next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) - query = math_ops.matmul(next_h, attn_w_2) - query = array_ops.reshape( - query, [self.hparams.num_children, 1, self.hparams.hidden_size]) - query = math_ops.tanh(query + attn_mem) - query = array_ops.reshape(query, [ - self.hparams.num_children * self.num_groups, self.hparams.hidden_size - ]) - query = math_ops.matmul(query, attn_v) - query = array_ops.reshape(query, - [self.hparams.num_children, self.num_groups]) - query = nn_ops.softmax(query) - query = array_ops.reshape(query, - [self.hparams.num_children, self.num_groups, 1]) - query = math_ops.reduce_sum(attn_mem * query, axis=1) - query = array_ops.concat([next_h, query], axis=1) - logits = math_ops.matmul(query, device_softmax) - logits /= self.hparams.temperature - if self.hparams.tanh_constant > 0: - logits = math_ops.tanh(logits) * self.hparams.tanh_constant - if self.hparams.logits_std_noise > 0: - num_in_logits = math_ops.cast( - array_ops.size(logits), dtype=dtypes.float32) - avg_norm = math_ops.divide( - linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) - logits_noise = random_ops.random_normal( - array_ops.shape(logits), - stddev=self.hparams.logits_std_noise * avg_norm) - logits = control_flow_ops.cond( - self.global_step > self.hparams.stop_noise_step, lambda: logits, - lambda: logits + logits_noise) - - if mode == "sample": - next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed) - elif mode == "greedy": - next_y = math_ops.argmax(logits, 1) - elif mode == "target": - next_y = array_ops.slice(y, [0, i], [-1, 1]) - else: - raise NotImplementedError - next_y = math_ops.cast(next_y, dtypes.int32) - next_y = array_ops.reshape(next_y, [self.hparams.num_children]) - actions = actions.write(i, next_y) - log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=next_y) - return i + 1, next_c, next_h, actions, log_probs - - loop_vars = [ - constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions, - array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32) - ] - loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars) - - last_c = loop_outputs[-4] - last_h = loop_outputs[-3] - actions = loop_outputs[-2].stack() - actions = array_ops.transpose(actions, [1, 0]) - log_probs = loop_outputs[-1] - return actions, log_probs - - def eval_placement(self, - sess, - child_id=0, - verbose=False): - grouping_actions, actions = sess.run([ - self.grouping_actions_cache, - self.actions_cache - ]) - grouping_actions = grouping_actions[child_id] - actions = actions[child_id] - if verbose: - global_step = sess.run(self.global_step) - if global_step % 100 == 0: - log_string = "op group assignments: " - for a in grouping_actions: - log_string += "{} ".format(a) - print(log_string[:-1]) - log_string = "group device assignments: " - for a in actions: - log_string += "{} ".format(a) - print(log_string[:-1]) - - for op in self.important_ops: - topo_order_index = self.name_to_topo_order_index[op.name] - group_index = grouping_actions[topo_order_index] - op.device = self.devices[actions[group_index]].name - try: - _, run_time, _ = self.cluster.MeasureCosts(self.item) - except errors.ResourceExhaustedError: - run_time = self.hparams.failing_signal - return run_time - - def update_reward(self, - sess, - run_time, - child_id=0, - verbose=False): - reward = self.compute_reward(run_time) - controller_ops = self.ops["controller"] - _, best_reward = sess.run( - [ - controller_ops["reward"]["update"][child_id], - controller_ops["best_reward"]["update"][child_id] - ], - feed_dict={ - controller_ops["reward"]["ph"][child_id]: reward, - }) - if verbose: - print(("run_time={:<.5f} reward={:<.5f} " - "best_reward={:<.5f}").format(run_time, reward, best_reward)) - - # Reward is a double, best_reward a float: allow for some slack in the - # comparison. - updated = abs(best_reward - reward) < 1e-6 - return updated - - def generate_grouping(self, sess): - controller_ops = self.ops["controller"] - grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"]) - return grouping_actions - - def generate_placement(self, grouping, sess): - controller_ops = self.ops["controller"] - feed_seq2seq_input_dict = {} - feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping - sess.run( - controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict) - - def process_reward(self, sess): - controller_ops = self.ops["controller"] - run_ops = [ - controller_ops["loss"], controller_ops["lr"], - controller_ops["grad_norm"], controller_ops["grad_norms"], - controller_ops["train_op"] - ] - sess.run(run_ops) - sess.run(controller_ops["baseline_update"]) - - def _get_train_ops(self, - loss, - tf_variables, - global_step, - grad_bound=1.25, - lr_init=1e-3, - lr_dec=0.9, - start_decay_step=10000, - decay_steps=100, - optimizer_type="adam"): - """Loss optimizer. - - Args: - loss: scalar tf tensor - tf_variables: list of training variables, typically - tf.compat.v1.trainable_variables() - global_step: global_step - grad_bound: max gradient norm - lr_init: initial learning rate - lr_dec: leaning rate decay coefficient - start_decay_step: start decaying learning rate after this many steps - decay_steps: apply decay rate factor at this step intervals - optimizer_type: optimizer type should be either adam or sgd - - Returns: - train_op: training op - learning_rate: scalar learning rate tensor - grad_norm: l2 norm of the gradient vector - all_grad_norms: l2 norm of each component - """ - lr_gstep = global_step - start_decay_step - - def f1(): - return constant_op.constant(lr_init) - - def f2(): - return learning_rate_decay.exponential_decay(lr_init, lr_gstep, - decay_steps, lr_dec, True) - - learning_rate = control_flow_ops.cond( - math_ops.less(global_step, start_decay_step), - f1, - f2, - name="learning_rate") - - if optimizer_type == "adam": - opt = adam.AdamOptimizer(learning_rate) - elif optimizer_type == "sgd": - opt = gradient_descent.GradientDescentOptimizer(learning_rate) - grads_and_vars = opt.compute_gradients(loss, tf_variables) - grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars]) - all_grad_norms = {} - clipped_grads = [] - clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0) - for g, v in grads_and_vars: - if g is not None: - if isinstance(g, tf_ops.IndexedSlices): - clipped = g.values / clipped_rate - norm_square = math_ops.reduce_sum(clipped * clipped) - clipped = tf_ops.IndexedSlices(clipped, g.indices) - else: - clipped = g / clipped_rate - norm_square = math_ops.reduce_sum(clipped * clipped) - all_grad_norms[v.name] = math_ops.sqrt(norm_square) - clipped_grads.append((clipped, v)) - - train_op = opt.apply_gradients(clipped_grads, global_step) - return train_op, learning_rate, grad_norm, all_grad_norms - - -def lstm(x, prev_c, prev_h, w_lstm, forget_bias): - """LSTM cell. - - Args: - x: tensors of size [num_children, hidden_size]. - prev_c: tensors of size [num_children, hidden_size]. - prev_h: same as prev_c. - w_lstm: . - forget_bias: . - - Returns: - next_c: - next_h: - """ - ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm) - i, f, o, g = array_ops.split(ifog, 4, axis=1) - i = math_ops.sigmoid(i) - f = math_ops.sigmoid(f + forget_bias) - o = math_ops.sigmoid(o) - g = math_ops.tanh(g) - next_c = i * g + f * prev_c - next_h = o * math_ops.tanh(next_c) - return next_c, next_h diff --git a/tensorflow/python/integration_testing/BUILD b/tensorflow/python/integration_testing/BUILD new file mode 100644 index 00000000000..30cff1016aa --- /dev/null +++ b/tensorflow/python/integration_testing/BUILD @@ -0,0 +1,9 @@ +# Description: +# This directory is only for tests that should be run on a pip whl. + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 6af56e7ab77..b98154b9095 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -16,7 +16,6 @@ py_library( "__init__.py", "estimator/__init__.py", "keras_parameterized.py", - "ops.py", ], srcs_version = "PY2AND3", deps = [ @@ -189,7 +188,9 @@ py_library( py_library( name = "initializers", srcs = [ - "initializers.py", + "initializers/__init__.py", + "initializers/initializers_v1.py", + "initializers/initializers_v2.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index 32c5e87a8f9..41a3f13e3eb 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -47,6 +47,8 @@ keras_packages = [ "tensorflow.python.keras.engine.training", "tensorflow.python.keras.estimator", "tensorflow.python.keras.initializers", + "tensorflow.python.keras.initializers.initializers_v1", + "tensorflow.python.keras.initializers.initializers_v2", "tensorflow.python.keras.layers.advanced_activations", "tensorflow.python.keras.layers.convolutional", "tensorflow.python.keras.layers.convolutional_recurrent", @@ -71,7 +73,6 @@ keras_packages = [ "tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer", "tensorflow.python.keras.mixed_precision.experimental.policy", "tensorflow.python.keras.models", - "tensorflow.python.keras.ops", "tensorflow.python.keras.optimizer_v2.adadelta", "tensorflow.python.keras.optimizer_v2.adagrad", "tensorflow.python.keras.optimizer_v2.adam", diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index 992010a809e..1eaed45c714 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -47,6 +47,7 @@ tf_py_test( size = "medium", srcs = ["applications_test.py"], shard_count = 36, + tags = ["no_rocm"], deps = [ ":applications", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py index 9b11c342536..fe353bcef15 100644 --- a/tensorflow/python/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -52,6 +52,8 @@ DENSENET201_WEIGHT_PATH_NO_TOP = ( BASE_WEIGTHS_PATH + 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + def dense_block(x, blocks, name): """A dense block. @@ -133,8 +135,7 @@ def DenseNet( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the DenseNet architecture. Reference paper: @@ -358,37 +359,12 @@ def DenseNet201(include_top=True, @keras_export('keras.applications.densenet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='torch') @keras_export('keras.applications.densenet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) @@ -405,7 +381,7 @@ DOC = """ Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is the one specified in your Keras config at `~/.keras/keras.json`. - + Arguments: include_top: whether to include the fully-connected layer at the top of the network. diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py index 4b9487dcdd6..0487450f880 100644 --- a/tensorflow/python/keras/applications/efficientnet.py +++ b/tensorflow/python/keras/applications/efficientnet.py @@ -28,9 +28,9 @@ import math import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -140,6 +140,8 @@ DENSE_KERNEL_INITIALIZER = { } } +layers = VersionAwareLayers() + def EfficientNet( width_coefficient, @@ -157,8 +159,7 @@ def EfficientNet( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the EfficientNet architecture using given scaling coefficients. Reference paper: @@ -664,18 +665,7 @@ def preprocess_input(x, data_format=None): # pylint: disable=unused-argument @keras_export('keras.applications.efficientnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index 7f338f82597..d4ffd372a10 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -28,9 +28,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -38,6 +38,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/inception_resnet_v2/') +layers = None @keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2', @@ -105,9 +106,11 @@ def InceptionResNetV2(include_top=True, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -378,36 +381,11 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): @keras_export('keras.applications.inception_resnet_v2.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.inception_resnet_v2.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py index fa44becfe48..21f65b1fbc7 100644 --- a/tensorflow/python/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -41,6 +41,8 @@ WEIGHTS_PATH_NO_TOP = ( 'https://storage.googleapis.com/tensorflow/keras-applications/' 'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.inception_v3.InceptionV3', 'keras.applications.InceptionV3') @@ -51,8 +53,7 @@ def InceptionV3( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the Inception v3 architecture. Reference paper: @@ -406,36 +407,11 @@ def conv2d_bn(x, @keras_export('keras.applications.inception_v3.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.inception_v3.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index d935282f98a..c79627c6aa7 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -67,9 +67,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -77,6 +77,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/mobilenet/') +layers = None @keras_export('keras.applications.mobilenet.MobileNet', @@ -155,9 +156,11 @@ def MobileNet(input_shape=None, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -439,36 +442,11 @@ def _depthwise_conv_block(inputs, @keras_export('keras.applications.mobilenet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.mobilenet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py index bdd21c3da62..59aeba572e3 100644 --- a/tensorflow/python/keras/applications/mobilenet_v2.py +++ b/tensorflow/python/keras/applications/mobilenet_v2.py @@ -80,9 +80,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -90,6 +90,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/mobilenet_v2/') +layers = None @keras_export('keras.applications.mobilenet_v2.MobileNetV2', @@ -173,9 +174,11 @@ def MobileNetV2(input_shape=None, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index 3da415dbb12..5c3117d8a47 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -44,9 +44,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -60,6 +60,8 @@ NASNET_MOBILE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-mobile-no-top.h5' NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5' NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5' +layers = VersionAwareLayers() + def NASNet( input_shape=None, @@ -74,8 +76,7 @@ def NASNet( pooling=None, classes=1000, default_size=None, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates a NASNet model. Reference paper: @@ -785,36 +786,11 @@ def _reduction_a_cell(ip, p, filters, block_id=None): @keras_export('keras.applications.nasnet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.nasnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/resnet.py b/tensorflow/python/keras/applications/resnet.py index 3e33bb04bdd..ecb3f31e0c9 100644 --- a/tensorflow/python/keras/applications/resnet.py +++ b/tensorflow/python/keras/applications/resnet.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -55,6 +55,8 @@ WEIGHTS_HASHES = { ('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3') } +layers = None + def ResNet(stack_fn, preact, @@ -129,9 +131,11 @@ def ResNet(stack_fn, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -517,17 +521,6 @@ def ResNet152(include_top=True, @keras_export('keras.applications.resnet50.preprocess_input', 'keras.applications.resnet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @@ -535,20 +528,6 @@ def preprocess_input(x, data_format=None): @keras_export('keras.applications.resnet50.decode_predictions', 'keras.applications.resnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) @@ -565,7 +544,7 @@ DOC = """ Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is the one specified in your Keras config at `~/.keras/keras.json`. - + Arguments: include_top: whether to include the fully-connected layer at the top of the network. diff --git a/tensorflow/python/keras/applications/resnet_v2.py b/tensorflow/python/keras/applications/resnet_v2.py index 2e1ee272c4b..a8f6e526ad5 100644 --- a/tensorflow/python/keras/applications/resnet_v2.py +++ b/tensorflow/python/keras/applications/resnet_v2.py @@ -37,8 +37,7 @@ def ResNet50V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet50V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -57,8 +56,7 @@ def ResNet50V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.ResNet101V2', @@ -70,8 +68,7 @@ def ResNet101V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet101V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -90,8 +87,7 @@ def ResNet101V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.ResNet152V2', @@ -103,8 +99,7 @@ def ResNet152V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet152V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -123,43 +118,17 @@ def ResNet152V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='tf') @keras_export('keras.applications.resnet_v2.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py index 534d2cff6be..3a523dc5dc3 100644 --- a/tensorflow/python/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/vgg16/' 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16') def VGG16( @@ -49,8 +51,7 @@ def VGG16( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the VGG16 model. Reference paper: @@ -227,37 +228,12 @@ def VGG16( @keras_export('keras.applications.vgg16.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @keras_export('keras.applications.vgg16.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py index 81c90e1ebb4..e4385cc8f6a 100644 --- a/tensorflow/python/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/vgg19/' 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19') def VGG19( @@ -49,8 +51,7 @@ def VGG19( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the VGG19 architecture. Reference: @@ -232,37 +233,12 @@ def VGG19( @keras_export('keras.applications.vgg19.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @keras_export('keras.applications.vgg19.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py index 5ea0f14cc79..d92bfd0f4c6 100644 --- a/tensorflow/python/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -30,9 +30,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -45,6 +45,8 @@ TF_WEIGHTS_PATH_NO_TOP = ( 'https://storage.googleapis.com/tensorflow/keras-applications/' 'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.xception.Xception', 'keras.applications.Xception') @@ -55,8 +57,7 @@ def Xception( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the Xception architecture. Optionally loads weights pre-trained on ImageNet. @@ -312,36 +313,11 @@ def Xception( @keras_export('keras.applications.xception.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.xception.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 8c0acc6f25f..35ef8def6d8 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -75,6 +75,7 @@ from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import moving_averages +from tensorflow.python.training.tracking import util as tracking_util from tensorflow.python.util import nest from tensorflow.python.util import object_identity from tensorflow.python.util import tf_contextlib @@ -546,6 +547,11 @@ def get_session(op_input_list=()): return session +# Inject the get_session function to tracking_util to avoid the backward +# dependency from TF to Keras. +tracking_util.register_session_provider(get_session) + + def get_graph(): if context.executing_eagerly(): global _GRAPH @@ -818,6 +824,9 @@ def name_scope(name): """ return ops.name_scope_v2(name) +# Export V1 version. +keras_export(v1=['keras.backend.name_scope'])(ops.name_scope_v1) + @keras_export('keras.backend.variable') def variable(value, dtype=None, name=None, constraint=None): diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 7c5124e923e..2e9338e030d 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -390,7 +390,7 @@ class CallbackList(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ if self._should_call_train_batch_hooks: self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) @@ -411,7 +411,7 @@ class CallbackList(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ if self._should_call_test_batch_hooks: self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs) @@ -432,7 +432,7 @@ class CallbackList(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ if self._should_call_predict_batch_hooks: self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) @@ -648,7 +648,7 @@ class Callback(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ # For backwards compatibility. self.on_batch_end(batch, logs=logs) @@ -681,7 +681,7 @@ class Callback(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ @doc_controls.for_subclass_implementers @@ -706,7 +706,7 @@ class Callback(object): Arguments: batch: integer, index of batch within the current epoch. - logs: dict. Metric results for this batch. + logs: dict. Aggregated metric results up until this batch. """ @doc_controls.for_subclass_implementers @@ -2005,6 +2005,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) def on_train_begin(self, logs=None): + self._global_train_batch = 0 self._push_writer(self._train_writer, self._train_step) def on_train_end(self, logs=None): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 5de4cacfa8a..b3d16907ada 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2018,14 +2018,16 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly()) return model - def _get_trace_file(self, logdir): + def _count_trace_file(self, logdir): profile_dir = os.path.join(logdir, 'plugins', 'profile') + count = 0 for (dirpath, dirnames, filenames) in os.walk(profile_dir): + del dirpath # unused del dirnames # unused for filename in filenames: if filename.endswith('.trace.json.gz'): - return os.path.join(dirpath, filename) - return None + count += 1 + return count def fitModelAndAssertKerasModelWritten(self, model): x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -2095,7 +2097,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_tagNameWithBatchNum(self): model = self._get_seq_model() @@ -2118,7 +2120,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_profileBatchRangeSingle(self): model = self._get_seq_model() @@ -2142,7 +2144,32 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) + + def test_TensorBoard_autoTrace_profileBatchRangeTwice(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch='10,10', write_graph=False) + + model.fit( + x, + y, + batch_size=3, + epochs=10, + validation_data=(x, y), + callbacks=[tb_cbk]) + + time.sleep(1) # Avoids the second profile over-writing the first. + + model.fit( + x, + y, + batch_size=3, + epochs=10, + validation_data=(x, y), + callbacks=[tb_cbk]) + self.assertEqual(2, self._count_trace_file(logdir=self.train_dir)) # Test case that replicates a Github issue. # https://github.com/tensorflow/tensorflow/issues/37543 @@ -2162,7 +2189,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)], ) # Verifies trace exists in the first logdir. - self.assertIsNotNone(self._get_trace_file(logdir=logdir)) + self.assertEqual(1, self._count_trace_file(logdir=logdir)) logdir = os.path.join(self.get_temp_dir(), 'tb2') model.fit( np.zeros((64, 1)), @@ -2171,7 +2198,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)], ) # Verifies trace exists in the second logdir. - self.assertIsNotNone(self._get_trace_file(logdir=logdir)) + self.assertEqual(1, self._count_trace_file(logdir=logdir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() @@ -2195,7 +2222,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_3'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_profileInvalidBatchRange(self): with self.assertRaises(ValueError): @@ -2237,7 +2264,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): # Enabled trace only on the 10000th batch, thus it should be empty. self.assertEmpty(summary_file.tensors) - self.assertIsNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(0, self._count_trace_file(logdir=self.train_dir)) class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase): diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index a7e25b77627..874ca84cab9 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -1811,6 +1811,27 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, self.assertEqual(bc.test_begin_batches, [0, 20, 40]) self.assertEqual(bc.test_end_batches, [19, 39, 49]) + @combinations.generate( + combinations.combine(distribution=all_strategies, mode=['eager'])) + def test_host_training_loop_truncate_to_epoch(self, distribution): + with distribution.scope(): + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + model = keras.Model(inputs, outputs) + + model.compile('sgd', 'mse', experimental_steps_per_execution=500) + + x, y = np.ones((100, 10)), np.ones((100, 1)) + bc = BatchCountingCB() + model.fit(x, y, batch_size=2, epochs=2, callbacks=[bc]) + self.assertEqual(bc.train_begin_batches, [0, 0]) + self.assertEqual(bc.train_end_batches, [49, 49]) + + x, y = np.ones((50, 10)), np.ones((50, 1)) + model.evaluate(x, y, batch_size=2, callbacks=[bc]) + self.assertEqual(bc.test_begin_batches, [0]) + self.assertEqual(bc.test_end_batches, [24]) + @combinations.generate( combinations.times( all_strategy_combinations_minus_default())) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 62a94afe51b..40fec808816 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -104,6 +104,38 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Users will just instantiate a layer and then treat it as a callable. + Arguments: + trainable: Boolean, whether the layer's variables should be trainable. + name: String name of the layer. + dtype: The dtype of the layer's computations and weights (default of + `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type + of the first input in TensorFlow 1). + dynamic: Set this to `True` if your layer should only be run eagerly, and + should not be used to generate a static computation graph. + This would be the case for a Tree-RNN or a recursive network, + for example, or generally for any layer that manipulates tensors + using Python control flow. If `False`, we assume that the layer can + safely be used to generate a static computation graph. + + Attributes: + name: The name of the layer (string). + dtype: The dtype of the layer's computations and weights. If mixed + precision is used with a `tf.keras.mixed_precision.experimental.Policy`, + this is instead just the dtype of the layer's weights, as the computations + are done in a different dtype. + losses: List of losses added to this layer (via `self.add_loss()`). + metrics: List of metrics added to this layer (via `self.add_metric()`).. + trainable_weights: List of variables to be included in backprop. + non_trainable_weights: List of variables that should not be + included in backprop. + weights: The concatenation of the lists trainable_weights and + non_trainable_weights (in this order). + trainable: Whether the layer should be trained (boolean), i.e. whether + its potentially-trainable weights should be returned as part of + `layer.trainable_weights`. + input_spec: Optional (list of) `InputSpec` object(s) specifying the + constraints on inputs that can be accepted by the layer. + We recommend that descendants of `Layer` implement the following methods: * `__init__()`: Defines custom layer attributes, and creates layer state @@ -223,35 +255,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): [Writing custom layers and models with Keras]( https://www.tensorflow.org/guide/keras/custom_layers_and_models) - Arguments: - trainable: Boolean, whether the layer's variables should be trainable. - name: String name of the layer. - dtype: The dtype of the layer's computations and weights (default of - `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type - of the first input in TensorFlow 1). - dynamic: Set this to `True` if your layer should only be run eagerly, and - should not be used to generate a static computation graph. - This would be the case for a Tree-RNN or a recursive network, - for example, or generally for any layer that manipulates tensors - using Python control flow. If `False`, we assume that the layer can - safely be used to generate a static computation graph. - - Attributes: - name: The name of the layer (string). - dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. - updates: List of update ops of this layer. - losses: List of losses added by this layer. - trainable_weights: List of variables to be included in backprop. - non_trainable_weights: List of variables that should not be - included in backprop. - weights: The concatenation of the lists trainable_weights and - non_trainable_weights (in this order). - trainable: Whether the layer should be trained (boolean). - input_spec: Optional (list of) `InputSpec` object(s) specifying the - constraints on inputs that can be accepted by the layer. + About the layer's `dtype` attribute: Each layer has a dtype, which is typically the dtype of the layer's computations and variables. A layer's dtype can be queried via the @@ -400,7 +404,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): `TensorShape` if the layer expects a list of inputs (one instance per input). """ - # Only record the build input shapes of overridden the build methods. + # Only record the build input shapes of overridden build methods. if not hasattr(self.build, '_is_default'): self._build_input_shape = input_shape self.built = True @@ -538,11 +542,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer if dtype.is_floating: - initializer = initializers.glorot_uniform() + initializer = initializers.get('glorot_uniform') # If dtype is DT_INT/DT_UINT, provide a default value `zero` # If dtype is DT_BOOL, provide a default value `FALSE` elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: - initializer = initializers.zeros() + initializer = initializers.get('zeros') # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? else: raise ValueError('An initializer for variable %s of type %s is required' diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 72024a0f658..c0c0d9d04be 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import control_flow_v2_func_graphs from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops_v2 @@ -395,6 +396,9 @@ def call_context(): return _call_context.call_context +control_flow_util_v2._register_keras_layer_context_function(call_context) # pylint: disable=protected-access + + class CallContext(object): """Keeps track of properties currently inside a Layer/Model's `call`. diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 1cf2450edf3..7b4ce8ad54c 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -402,7 +402,7 @@ class Layer(base_layer.Layer): if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer if dtype.is_floating: - initializer = initializers.glorot_uniform() + initializer = initializers.get('glorot_uniform') # If dtype is DT_INT/DT_UINT, provide a default value `zero` # If dtype is DT_BOOL, provide a default value `FALSE` elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 1be2f8449ce..3994db4a541 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1155,19 +1155,37 @@ class DataHandler(object): def enumerate_epochs(self): """Yields `(epoch, tf.data.Iterator)`.""" - data_iterator = iter(self._dataset) - for epoch in range(self._initial_epoch, self._epochs): - if self._insufficient_data: # Set by `catch_stop_iteration`. - break - if self._adapter.should_recreate_iterator(): - if ds_context.has_strategy(): - # TODO(b/138326910): remove this when MultiDeviceIterator is a - # CompositeTensor (unless this is more efficient) - data_iterator._initializer # pylint: disable=pointless-statement, protected-access - else: - data_iterator = iter(self._dataset) - yield epoch, data_iterator - self._adapter.on_epoch_end() + with self._truncate_execution_to_epoch(): + data_iterator = iter(self._dataset) + for epoch in range(self._initial_epoch, self._epochs): + if self._insufficient_data: # Set by `catch_stop_iteration`. + break + if self._adapter.should_recreate_iterator(): + if ds_context.has_strategy(): + # TODO(b/138326910): remove this when MultiDeviceIterator is a + # CompositeTensor (unless this is more efficient) + data_iterator._initializer # pylint: disable=pointless-statement, protected-access + else: + data_iterator = iter(self._dataset) + yield epoch, data_iterator + self._adapter.on_epoch_end() + + @contextlib.contextmanager + def _truncate_execution_to_epoch(self): + """Truncates steps per execution to at most one epoch.""" + should_truncate = ( + self._inferred_steps is not None and + self._steps_per_execution_value > self._inferred_steps) + original_value = self._steps_per_execution_value + try: + if should_truncate: + self._steps_per_execution.assign(self._inferred_steps) + self._steps_per_execution_value = self._inferred_steps + yield + finally: + if should_truncate: + self._steps_per_execution.assign(original_value) + self._steps_per_execution_value = original_value @contextlib.contextmanager def catch_stop_iteration(self): @@ -1309,7 +1327,7 @@ def _make_class_weight_map_fn(class_weight): raise ValueError(error_msg) class_weight_tensor = ops.convert_to_tensor_v2( - [int(class_weight[c]) for c in class_ids], dtype="int64") + [class_weight[int(c)] for c in class_ids]) def _class_weights_map_fn(*data): """Convert `class_weight` to `sample_weight`.""" @@ -1379,10 +1397,11 @@ def train_validation_split(arrays, validation_split, shuffle=True): return isinstance(t, tensor_types) or t is None flat_arrays = nest.flatten(arrays) - if not all(_can_split(t) for t in flat_arrays): + unsplitable = [type(t) for t in flat_arrays if not _can_split(t)] + if unsplitable: raise ValueError( "`validation_split` is only supported for Tensors or NumPy " - "arrays, found: {}".format(arrays)) + "arrays, found following types in the input: {}".format(unsplitable)) if all(t is None for t in flat_arrays): return arrays, arrays @@ -1402,6 +1421,14 @@ def train_validation_split(arrays, validation_split, shuffle=True): train_indices = indices[:split_at] val_indices = indices[split_at:] + if split_at == 0 or split_at == batch_dim: + raise ValueError( + "Training data contains {batch_dim} samples, which is not sufficient " + "to split it into a validation and training set as specified by " + "`validation_split={validation_split}`. Either provide more data, or a " + "different value for the `validation_split` argument." .format( + batch_dim=batch_dim, validation_split=validation_split)) + def _split(t, indices): if t is None: return t diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 24cfc1f2bf0..37b204c87ff 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -1054,6 +1054,12 @@ class TestValidationSplit(keras_parameterized.TestCase): data_adapter.train_validation_split( lambda: np.ones((10, 1)), validation_split=0.2) + def test_validation_split_examples_too_few(self): + with self.assertRaisesRegexp( + ValueError, 'not sufficient to split it'): + data_adapter.train_validation_split( + np.ones((1, 10)), validation_split=0.2) + def test_validation_split_none(self): train_sw, val_sw = data_adapter.train_validation_split( None, validation_split=0.2) diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 667899660c1..a41aa2e891b 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -183,8 +183,7 @@ class Sequential(training.Model): set_inputs = False if not self._layers: if isinstance(layer, input_layer.InputLayer): - # Corner case where the user passes an InputLayer layer via `add`. - assert len(nest.flatten(layer._inbound_nodes[-1].output_tensors)) == 1 + # Case where the user passes an Input or InputLayer layer via `add`. set_inputs = True else: batch_shape, dtype = training_utils.get_input_shape_and_dtype(layer) @@ -199,13 +198,12 @@ class Sequential(training.Model): set_inputs = True if set_inputs: - # If an input layer (placeholder) is available. - if len(nest.flatten(layer._inbound_nodes[-1].output_tensors)) != 1: + outputs = nest.flatten(layer._inbound_nodes[-1].output_tensors) + if len(outputs) != 1: raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) - self.outputs = [ - nest.flatten(layer._inbound_nodes[-1].output_tensors)[0] - ] + self.outputs = outputs self.inputs = layer_utils.get_source_inputs(self.outputs[0]) + self.built = True elif self.outputs: # If the model is being built continuously on top of an input layer: @@ -214,10 +212,6 @@ class Sequential(training.Model): if len(nest.flatten(output_tensor)) != 1: raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) self.outputs = [output_tensor] - - if self.outputs: - # True if set_inputs or self._is_graph_network or if adding a layer - # to an already built deferred seq model. self.built = True if set_inputs or self._is_graph_network: @@ -267,15 +261,14 @@ class Sequential(training.Model): self.built = True def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name - if self._build_input_shape is None: - input_shapes = nest.map_structure(_get_shape_tuple, inputs) - self._build_input_shape = input_shapes - if self._is_graph_network: if not self.built: self._init_graph_network(self.inputs, self.outputs, name=self.name) return super(Sequential, self).call(inputs, training=training, mask=mask) + if self._build_input_shape is None: + self._build_input_shape = nest.map_structure(_get_shape_tuple, inputs) + outputs = inputs # handle the corner case where self.layers is empty for layer in self.layers: # During each iteration, `inputs` are the inputs to `layer`, and `outputs` @@ -364,21 +357,16 @@ class Sequential(training.Model): def get_config(self): layer_configs = [] - for layer in self.layers: + for layer in super(Sequential, self).layers: + # `super().layers` include the InputLayer if available (it is filtered out + # of `self.layers`). Note that `self._layers` is managed by the + # tracking infrastructure and should not be used. layer_configs.append(generic_utils.serialize_keras_object(layer)) - # When constructed using an `InputLayer` the first non-input layer may not - # have the shape information to reconstruct `Sequential` as a graph network. - if (self._is_graph_network and layer_configs and - 'batch_input_shape' not in layer_configs[0]['config'] and - isinstance(self._layers[0], input_layer.InputLayer)): - batch_input_shape = self._layers[0]._batch_input_shape - layer_configs[0]['config']['batch_input_shape'] = batch_input_shape - config = { 'name': self.name, 'layers': copy.deepcopy(layer_configs) } - if self._build_input_shape is not None: + if not self._is_graph_network and self._build_input_shape is not None: config['build_input_shape'] = self._build_input_shape return config diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index a9694cb69be..682967b7f02 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -429,6 +429,18 @@ class TestSequential(keras_parameterized.TestCase): model.pop() self.assertEqual(model._layers[-1], layer) + def test_config_preserves_input_layer(self): + model = keras.Sequential([ + keras.Input((None,), name='my_embedding_input', dtype='int32'), + keras.layers.Embedding(32, 32), + keras.layers.Dense(3), + ]) + config = model.get_config() + new_model = keras.Sequential.from_config(config) + self.assertTrue(new_model.built) + self.assertEqual(new_model._layers[0].dtype, 'int32') + self.assertEqual(new_model._layers[0].name, 'my_embedding_input') + class TestSequentialEagerIntegration(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 94570d96208..64b5ff16f21 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -341,7 +341,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): on TPUs or small models with a large Python overhead. Note that if this value is set to `N`, `Callback.on_batch` methods will only be called every `N` batches. This currently defaults to `1`. At most, - one full epoch can be run each execution. + one full epoch will be run each execution. If a number larger than + the size of the epoch is passed, the execution will be truncated + to the size of the epoch. Raises: ValueError: In case of invalid arguments for @@ -398,7 +400,41 @@ class Model(network.Network, version_utils.ModelVersionSelector): @property def metrics(self): - """Returns the model's metrics added using `compile`, `add_metric` APIs.""" + """Returns the model's metrics added using `compile`, `add_metric` APIs. + + Note: `metrics` are available only after a `keras.Model` has been + trained/evaluated on actual data. + + Examples: + + >>> inputs = tf.keras.layers.Input(shape=(3,)) + >>> outputs = tf.keras.layers.Dense(2)(inputs) + >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) + >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) + >>> [m.name for m in model.metrics] + [] + + >>> x = np.random.random((2, 3)) + >>> y = np.random.randint(0, 2, (2, 2)) + >>> _ = model.fit(x, y, verbose=0) + >>> [m.name for m in model.metrics] + ['loss', 'mae'] + + >>> inputs = tf.keras.layers.Input(shape=(3,)) + >>> d = tf.keras.layers.Dense(2, name='out') + >>> output_1 = d(inputs) + >>> output_2 = d(inputs) + >>> model = tf.keras.models.Model( + ... inputs=inputs, outputs=[output_1, output_2]) + >>> model.add_metric( + ... tf.reduce_sum(output_2), name='mean', aggregation='mean') + >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) + >>> _ = model.fit(x, (y, y), verbose=0) + >>> [m.name for m in model.metrics] + ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', + 'out_1_acc', 'mean'] + + """ metrics = [] if self._is_compiled: # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects @@ -415,7 +451,39 @@ class Model(network.Network, version_utils.ModelVersionSelector): @property def metrics_names(self): - """Returns the model's display labels for all outputs.""" + """Returns the model's display labels for all outputs. + + Note: `metrics_names` are available only after a `keras.Model` has been + trained/evaluated on actual data. + + Examples: + + >>> inputs = tf.keras.layers.Input(shape=(3,)) + >>> outputs = tf.keras.layers.Dense(2)(inputs) + >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) + >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) + >>> model.metrics_names + [] + + >>> x = np.random.random((2, 3)) + >>> y = np.random.randint(0, 2, (2, 2)) + >>> _ = model.fit(x, y, verbose=0) + >>> model.metrics_names + ['loss', 'mae'] + + >>> inputs = tf.keras.layers.Input(shape=(3,)) + >>> d = tf.keras.layers.Dense(2, name='out') + >>> output_1 = d(inputs) + >>> output_2 = d(inputs) + >>> model = tf.keras.models.Model( + ... inputs=inputs, outputs=[output_1, output_2]) + >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) + >>> _ = model.fit(x, (y, y), verbose=0) + >>> model.metrics_names + ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', + 'out_1_acc'] + + """ # This property includes all output names including `loss` and per-output # losses for backward compatibility. @@ -473,7 +541,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): """The logic for one training step. This method can be overridden to support custom training logic. - This method is called by `Model._make_train_function`. + This method is called by `Model.make_train_function`. This method should contain the mathemetical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, @@ -481,7 +549,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): Configuration details for *how* this logic is run (e.g. `tf.function` and `tf.distribute.Strategy` settings), should be left to - `Model._make_train_function`, which can also be overridden. + `Model.make_train_function`, which can also be overridden. Arguments: data: A nested structure of `Tensor`s. @@ -523,7 +591,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): Typically, this method directly controls `tf.function` and `tf.distribute.Strategy` settings, and delegates the actual training - logic to `Model._train_step`. + logic to `Model.train_step`. This function is cached the first time `Model.fit` or `Model.train_on_batch` is called. The cache is cleared whenever diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index d8d44d18033..404175af137 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1453,7 +1453,7 @@ class LossWeightingTest(keras_parameterized.TestCase): batch_size = 5 epochs = 10 weighted_class = 3 - weight = 10. + weight = .5 train_samples = 1000 test_samples = 1000 input_dim = 5 diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index d901c4986db..680f33f75a5 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -56,6 +56,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import util as tf_losses_utils +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -1049,11 +1050,21 @@ def has_symbolic_tensors(ls): def has_tensors(ls): + """Returns true if `ls` contains tensors.""" + # Note: at some point in time ragged tensors didn't count as tensors, so this + # returned false for ragged tensors. Making this return true fails some tests + # which would then require a steps_per_epoch argument. if isinstance(ls, (list, tuple)): - return any(tensor_util.is_tensor(v) for v in ls) + return any( + tensor_util.is_tensor(v) and + not isinstance(v, ragged_tensor.RaggedTensor) for v in ls) if isinstance(ls, dict): - return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls)) - return tensor_util.is_tensor(ls) + return any( + tensor_util.is_tensor(v) and + not isinstance(v, ragged_tensor.RaggedTensor) + for _, v in six.iteritems(ls)) + return tensor_util.is_tensor(ls) and not isinstance( + ls, ragged_tensor.RaggedTensor) def get_metric_name(metric, weighted=False): diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py deleted file mode 100644 index 58a90ccadc3..00000000000 --- a/tensorflow/python/keras/initializers.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras initializer serialization / deserialization. -""" -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=g-import-not-at-top -# pylint: disable=g-bad-import-order -# pylint: disable=invalid-name -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import six - -from tensorflow.python import tf2 -from tensorflow.python.framework import dtypes -from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.ops import init_ops_v2 - -# These imports are brought in so that keras.initializers.deserialize -# has them available in module_objects. -from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2 -from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2 -from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2 -from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2 -from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2 -from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2 -from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2 -from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2 -from tensorflow.python.ops.init_ops_v2 import lecun_uniform as lecun_uniformV2 -from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2 -from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2 -from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2 -from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2 -from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2 -from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2 -from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2 - -if tf2.enabled(): - Constant = ConstantV2 - GlorotNormal = GlorotNormalV2 - GlorotUniform = GlorotUniformV2 - he_normal = he_normalV2 - he_uniform = he_uniformV2 - Identity = IdentityV2 - Initializer = InitializerV2 - lecun_normal = lecun_normalV2 - lecun_uniform = lecun_uniformV2 - Ones = OnesV2 - Orthogonal = OrthogonalV2 - VarianceScaling = VarianceScalingV2 - Zeros = ZerosV2 -else: - from tensorflow.python.ops.init_ops import Constant - from tensorflow.python.ops.init_ops import GlorotNormal - from tensorflow.python.ops.init_ops import GlorotUniform - from tensorflow.python.ops.init_ops import he_normal - from tensorflow.python.ops.init_ops import he_uniform - from tensorflow.python.ops.init_ops import Identity - from tensorflow.python.ops.init_ops import Initializer - from tensorflow.python.ops.init_ops import lecun_normal - from tensorflow.python.ops.init_ops import lecun_uniform - from tensorflow.python.ops.init_ops import Ones - from tensorflow.python.ops.init_ops import Orthogonal - from tensorflow.python.ops.init_ops import VarianceScaling - from tensorflow.python.ops.init_ops import Zeros - -from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormalV1 -from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniformV1 -from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormalV1 - -from tensorflow.python.util.tf_export import keras_export - - -@keras_export(v1=['keras.initializers.TruncatedNormal', - 'keras.initializers.truncated_normal']) -class TruncatedNormalV1(TFTruncatedNormalV1): - """Initializer that generates a truncated normal distribution. - - These values are similar to values from a `random_normal_initializer` - except that values more than two standard deviations from the mean - are discarded and re-drawn. This is the recommended initializer for - neural network weights and filters. - - Args: - mean: a python scalar or a scalar tensor. Mean of the random values to - generate. Defaults to 0. - stddev: a python scalar or a scalar tensor. Standard deviation of the random - values to generate. Defaults to 0.05. - seed: A Python integer. Used to create random seeds. See - `tf.compat.v1.set_random_seed` for behavior. - dtype: The data type. Only floating point types are supported. - - Returns: - A TruncatedNormal instance. - """ - - def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): - super(TruncatedNormalV1, self).__init__( - mean=mean, stddev=stddev, seed=seed, dtype=dtype) - - -@keras_export(v1=['keras.initializers.RandomUniform', - 'keras.initializers.uniform', - 'keras.initializers.random_uniform']) -class RandomUniformV1(TFRandomUniformV1): - """Initializer that generates tensors with a uniform distribution. - - Args: - minval: A python scalar or a scalar tensor. Lower bound of the range of - random values to generate. Defaults to -0.05. - maxval: A python scalar or a scalar tensor. Upper bound of the range of - random values to generate. Defaults to 0.05. - seed: A Python integer. Used to create random seeds. See - `tf.compat.v1.set_random_seed` for behavior. - dtype: The data type. - - Returns: - A RandomUniform instance. - """ - - def __init__(self, minval=-0.05, maxval=0.05, seed=None, - dtype=dtypes.float32): - super(RandomUniformV1, self).__init__( - minval=minval, maxval=maxval, seed=seed, dtype=dtype) - - -@keras_export(v1=['keras.initializers.RandomNormal', - 'keras.initializers.normal', - 'keras.initializers.random_normal']) -class RandomNormalV1(TFRandomNormalV1): - """Initializer that generates tensors with a normal distribution. - - Args: - mean: a python scalar or a scalar tensor. Mean of the random values to - generate. Defaults to 0. - stddev: a python scalar or a scalar tensor. Standard deviation of the random - values to generate. Defaults to 0.05. - seed: A Python integer. Used to create random seeds. See - `tf.compat.v1.set_random_seed` for behavior. - dtype: The data type. Only floating point types are supported. - - Returns: - RandomNormal instance. - """ - - def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): - super(RandomNormalV1, self).__init__( - mean=mean, stddev=stddev, seed=seed, dtype=dtype) - - -if tf2.enabled(): - RandomNormal = RandomNormalV2 - RandomUniform = RandomUniformV2 - TruncatedNormal = TruncatedNormalV2 -else: - RandomNormal = RandomNormalV1 - RandomUniform = RandomUniformV1 - TruncatedNormal = TruncatedNormalV1 - -# Compatibility aliases -zero = zeros = Zeros -one = ones = Ones -constant = Constant -uniform = random_uniform = RandomUniform -normal = random_normal = RandomNormal -truncated_normal = TruncatedNormal -identity = Identity -orthogonal = Orthogonal -glorot_normal = GlorotNormal -glorot_uniform = GlorotUniform - -# Utility functions - - -@keras_export('keras.initializers.serialize') -def serialize(initializer): - return serialize_keras_object(initializer) - - -@keras_export('keras.initializers.deserialize') -def deserialize(config, custom_objects=None): - """Return an `Initializer` object from its config.""" - if tf2.enabled(): - # Class names are the same for V1 and V2 but the V2 classes - # are aliased in this file so we need to grab them directly - # from `init_ops_v2`. - module_objects = { - obj_name: getattr(init_ops_v2, obj_name) - for obj_name in dir(init_ops_v2) - } - else: - module_objects = globals() - return deserialize_keras_object( - config, - module_objects=module_objects, - custom_objects=custom_objects, - printable_module_name='initializer') - - -@keras_export('keras.initializers.get') -def get(identifier): - if identifier is None: - return None - if isinstance(identifier, dict): - return deserialize(identifier) - elif isinstance(identifier, six.string_types): - identifier = str(identifier) - # We have to special-case functions that return classes. - # TODO(omalleyt): Turn these into classes or class aliases. - special_cases = ['he_normal', 'he_uniform', 'lecun_normal', 'lecun_uniform'] - if identifier in special_cases: - # Treat like a class. - return deserialize({'class_name': identifier, 'config': {}}) - return deserialize(identifier) - elif callable(identifier): - return identifier - else: - raise ValueError('Could not interpret initializer identifier: ' + - str(identifier)) diff --git a/tensorflow/python/keras/initializers/__init__.py b/tensorflow/python/keras/initializers/__init__.py new file mode 100644 index 00000000000..828a5b9ca49 --- /dev/null +++ b/tensorflow/python/keras/initializers/__init__.py @@ -0,0 +1,162 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras initializer serialization / deserialization. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import six + +from tensorflow.python import tf2 +from tensorflow.python.keras.initializers import initializers_v1 +from tensorflow.python.keras.initializers import initializers_v2 +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import init_ops +from tensorflow.python.util import tf_inspect as inspect +from tensorflow.python.util.tf_export import keras_export + + +# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it +# thread-local to avoid concurrent mutations. +LOCAL = threading.local() + + +def populate_deserializable_objects(): + """Populates dict ALL_OBJECTS with every built-in initializer. + """ + global LOCAL + if not hasattr(LOCAL, 'ALL_OBJECTS'): + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = None + + if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled(): + # Objects dict is already generated for the proper TF version: + # do nothing. + return + + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = tf2.enabled() + + # Compatibility aliases (need to exist in both V1 and V2). + LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant + LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal + LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform + LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal + LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform + LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity + LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal + LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform + LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones + LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal + LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal + LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform + LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal + LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling + LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros + + # Out of an abundance of caution we also include these aliases that have + # a non-zero probability of having been included in saved configs in the past. + LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal + LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform + LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal + LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform + LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal + LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform + + if tf2.enabled(): + # For V2, entries are generated automatically based on the content of + # initializers_v2.py. + v2_objs = {} + base_cls = initializers_v2.Initializer + generic_utils.populate_dict_with_module_objects( + v2_objs, + [initializers_v2], + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + for key, value in v2_objs.items(): + LOCAL.ALL_OBJECTS[key] = value + # Functional aliases. + LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value + else: + # V1 initializers. + v1_objs = { + 'Constant': init_ops.Constant, + 'GlorotNormal': init_ops.GlorotNormal, + 'GlorotUniform': init_ops.GlorotUniform, + 'Identity': init_ops.Identity, + 'Ones': init_ops.Ones, + 'Orthogonal': init_ops.Orthogonal, + 'VarianceScaling': init_ops.VarianceScaling, + 'Zeros': init_ops.Zeros, + 'HeNormal': initializers_v1.HeNormal, + 'HeUniform': initializers_v1.HeUniform, + 'LecunNormal': initializers_v1.LecunNormal, + 'LecunUniform': initializers_v1.LecunUniform, + 'RandomNormal': initializers_v1.RandomNormal, + 'RandomUniform': initializers_v1.RandomUniform, + 'TruncatedNormal': initializers_v1.TruncatedNormal, + } + for key, value in v1_objs.items(): + LOCAL.ALL_OBJECTS[key] = value + # Functional aliases. + LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value + + # More compatibility aliases. + LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal'] + LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform'] + LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones'] + LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros'] + + +# For backwards compatibility, we populate this file with the objects +# from ALL_OBJECTS. We make no guarantees as to whether these objects will +# using their correct version. +populate_deserializable_objects() +globals().update(LOCAL.ALL_OBJECTS) + +# Utility functions + + +@keras_export('keras.initializers.serialize') +def serialize(initializer): + return generic_utils.serialize_keras_object(initializer) + + +@keras_export('keras.initializers.deserialize') +def deserialize(config, custom_objects=None): + """Return an `Initializer` object from its config.""" + populate_deserializable_objects() + return generic_utils.deserialize_keras_object( + config, + module_objects=LOCAL.ALL_OBJECTS, + custom_objects=custom_objects, + printable_module_name='initializer') + + +@keras_export('keras.initializers.get') +def get(identifier): + if identifier is None: + return None + if isinstance(identifier, dict): + return deserialize(identifier) + elif isinstance(identifier, six.string_types): + identifier = str(identifier) + return deserialize(identifier) + elif callable(identifier): + return identifier + else: + raise ValueError('Could not interpret initializer identifier: ' + + str(identifier)) diff --git a/tensorflow/python/keras/initializers/initializers_v1.py b/tensorflow/python/keras/initializers/initializers_v1.py new file mode 100644 index 00000000000..63b81065e8d --- /dev/null +++ b/tensorflow/python/keras/initializers/initializers_v1.py @@ -0,0 +1,112 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras initializers for TF 1. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import init_ops +from tensorflow.python.util.tf_export import keras_export + +keras_export(v1=['keras.initializers.Zeros', 'keras.initializers.zeros'])( + init_ops.Zeros) +keras_export(v1=['keras.initializers.Ones', 'keras.initializers.ones'])( + init_ops.Ones) +keras_export(v1=['keras.initializers.Constant', 'keras.initializers.constant'])( + init_ops.Constant) +keras_export(v1=['keras.initializers.VarianceScaling'])( + init_ops.VarianceScaling) +keras_export(v1=['keras.initializers.Orthogonal', + 'keras.initializers.orthogonal'])(init_ops.Orthogonal) +keras_export(v1=['keras.initializers.Identity', + 'keras.initializers.identity'])(init_ops.Identity) +keras_export(v1=['keras.initializers.glorot_uniform'])(init_ops.GlorotUniform) +keras_export(v1=['keras.initializers.glorot_normal'])(init_ops.GlorotNormal) + + +@keras_export(v1=['keras.initializers.RandomNormal', + 'keras.initializers.random_normal', + 'keras.initializers.normal']) +class RandomNormal(init_ops.RandomNormal): + + def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): + super(RandomNormal, self).__init__( + mean=mean, stddev=stddev, seed=seed, dtype=dtype) + + +@keras_export(v1=['keras.initializers.RandomUniform', + 'keras.initializers.random_uniform', + 'keras.initializers.uniform']) +class RandomUniform(init_ops.RandomUniform): + + def __init__(self, minval=-0.05, maxval=0.05, seed=None, + dtype=dtypes.float32): + super(RandomUniform, self).__init__( + minval=minval, maxval=maxval, seed=seed, dtype=dtype) + + +@keras_export(v1=['keras.initializers.TruncatedNormal', + 'keras.initializers.truncated_normal']) +class TruncatedNormal(init_ops.TruncatedNormal): + + def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): + super(TruncatedNormal, self).__init__( + mean=mean, stddev=stddev, seed=seed, dtype=dtype) + + +@keras_export(v1=['keras.initializers.lecun_normal']) +class LecunNormal(init_ops.VarianceScaling): + + def __init__(self, seed=None): + super(LecunNormal, self).__init__( + scale=1., mode='fan_in', distribution='truncated_normal', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export(v1=['keras.initializers.lecun_uniform']) +class LecunUniform(init_ops.VarianceScaling): + + def __init__(self, seed=None): + super(LecunUniform, self).__init__( + scale=1., mode='fan_in', distribution='uniform', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export(v1=['keras.initializers.he_normal']) +class HeNormal(init_ops.VarianceScaling): + + def __init__(self, seed=None): + super(HeNormal, self).__init__( + scale=2., mode='fan_in', distribution='truncated_normal', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export(v1=['keras.initializers.he_uniform']) +class HeUniform(init_ops.VarianceScaling): + + def __init__(self, seed=None): + super(HeUniform, self).__init__( + scale=2., mode='fan_in', distribution='uniform', seed=seed) + + def get_config(self): + return {'seed': self.seed} diff --git a/tensorflow/python/keras/initializers/initializers_v2.py b/tensorflow/python/keras/initializers/initializers_v2.py new file mode 100644 index 00000000000..69dca335857 --- /dev/null +++ b/tensorflow/python/keras/initializers/initializers_v2.py @@ -0,0 +1,751 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras initializers for TF 2. +""" +# pylint: disable=g-classes-have-attributes +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras import backend +from tensorflow.python.ops import init_ops_v2 +from tensorflow.python.util.tf_export import keras_export + + +@keras_export('keras.initializers.Initializer') +class Initializer(object): + """Initializer base class: all Keras initializers inherit from this class. + + Initializers should implement a `__call__` method with the following + signature: + + ```python + def __call__(self, shape, dtype=None)`: + # returns a tensor of shape `shape` and dtype `dtype` + # containing values drawn from a distribution of your choice. + ``` + + Optionally, you an also implement the method `get_config` and the class + method `from_config` in order to support serialization -- just like with + any Keras object. + + Here's a simple example: a random normal initializer. + + ```python + import tensorflow as tf + + class ExampleRandomNormal(tf.keras.initializers.Initializer): + + def __init__(self, mean, stddev): + self.mean = mean + self.stddev = stddev + + def __call__(self, shape, dtype=None)`: + return tf.random.normal( + shape, mean=self.mean, stddev=self.stddev, dtype=dtype) + + def get_config(self): # To support serialization + return {"mean": self.mean, "stddev": self.stddev} + ``` + + Note that we don't have to implement `from_config` in the example above since + the constructor arguments of the class the keys in the config returned by + `get_config` are the same. In this case, the default `from_config` + works fine. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. + """ + raise NotImplementedError + + def get_config(self): + """Returns the configuration of the initializer as a JSON-serializable dict. + + Returns: + A JSON-serializable Python dict. + """ + return {} + + @classmethod + def from_config(cls, config): + """Instantiates an initializer from a configuration dictionary. + + Example: + + ```python + initializer = RandomUniform(-1, 1) + config = initializer.get_config() + initializer = RandomUniform.from_config(config) + ``` + + Args: + config: A Python dictionary, the output of `get_config`. + + Returns: + A `tf.keras.initializers.Initializer` instance. + """ + config.pop('dtype', None) + return cls(**config) + + +@keras_export('keras.initializers.Zeros', 'keras.initializers.zeros', v1=[]) +class Zeros(init_ops_v2.Zeros, Initializer): + """Initializer that generates tensors initialized to 0. + + Also available via the shortcut function `tf.keras.initializers.zeros`. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.Zeros() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.Zeros() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`). + """ + return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[]) +class Ones(init_ops_v2.Ones, Initializer): + """Initializer that generates tensors initialized to 1. + + Also available via the shortcut function `tf.keras.initializers.ones`. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.Ones() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.Ones() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`). + """ + return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.Constant', + 'keras.initializers.constant', + v1=[]) +class Constant(Initializer): + """Initializer that generates tensors with constant values. + + Also available via the shortcut function `tf.keras.initializers.constant`. + + Only scalar values are allowed. + The constant value provided must be convertible to the dtype requested + when calling the initializer. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.Constant(3.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.Constant(3.) + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + value: A Python scalar. + """ + + def __init__(self, value=0): + self.value = value + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized to `self.value`. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. If not specified, + `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`). + """ + return constant_op.constant( + self.value, dtype=_get_dtype(dtype), shape=shape) + + def get_config(self): + return {'value': self.value} + + +@keras_export('keras.initializers.RandomUniform', + 'keras.initializers.random_uniform', + v1=[]) +class RandomUniform(init_ops_v2.RandomUniform, Initializer): + """Initializer that generates tensors with a uniform distribution. + + Also available via the shortcut function + `tf.keras.initializers.random_uniform`. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.RandomUniform(minval=0., maxval=1.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.RandomUniform(minval=0., maxval=1.) + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + minval: A python scalar or a scalar tensor. Lower bound of the range of + random values to generate (inclusive). + maxval: A python scalar or a scalar tensor. Upper bound of the range of + random values to generate (exclusive). + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only floating point and integer + types are supported. If not specified, + `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`). + """ + return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.RandomNormal', + 'keras.initializers.random_normal', + v1=[]) +class RandomNormal(init_ops_v2.RandomNormal, Initializer): + """Initializer that generates tensors with a normal distribution. + + Also available via the shortcut function + `tf.keras.initializers.random_normal`. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.) + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + mean: a python scalar or a scalar tensor. Mean of the random values to + generate. + stddev: a python scalar or a scalar tensor. Standard deviation of the random + values to generate. + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized to random normal values. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`) + """ + return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.TruncatedNormal', + 'keras.initializers.truncated_normal', + v1=[]) +class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer): + """Initializer that generates a truncated normal distribution. + + Also available via the shortcut function + `tf.keras.initializers.truncated_normal`. + + The values generated are similar to values from a + `tf.keras.initializers.RandomNormal` initializer except that values more + than two standard deviations from the mean are + discarded and re-drawn. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.TruncatedNormal(mean=0., stddev=1.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.TruncatedNormal(mean=0., stddev=1.) + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + mean: a python scalar or a scalar tensor. Mean of the random values + to generate. + stddev: a python scalar or a scalar tensor. Standard deviation of the + random values to generate. + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized to random normal values (truncated). + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`) + """ + return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.VarianceScaling', + 'keras.initializers.variance_scaling', + v1=[]) +class VarianceScaling(init_ops_v2.VarianceScaling, Initializer): + """Initializer capable of adapting its scale to the shape of weights tensors. + + Also available via the shortcut function + `tf.keras.initializers.variance_scaling`. + + With `distribution="truncated_normal" or "untruncated_normal"`, samples are + drawn from a truncated/untruncated normal distribution with a mean of zero and + a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)` + where n is: + + - number of input units in the weight tensor, if mode = "fan_in" + - number of output units, if mode = "fan_out" + - average of the numbers of input and output units, if mode = "fan_avg" + + With `distribution="uniform"`, samples are drawn from a uniform distribution + within [-limit, limit], with `limit = sqrt(3 * scale / n)`. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.VarianceScaling( + ... scale=0.1, mode='fan_in', distribution='uniform') + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.VarianceScaling( + ... scale=0.1, mode='fan_in', distribution='uniform') + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + scale: Scaling factor (positive float). + mode: One of "fan_in", "fan_out", "fan_avg". + distribution: Random distribution to use. One of "truncated_normal", + "untruncated_normal" and "uniform". + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`) + """ + return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.Orthogonal', + 'keras.initializers.orthogonal', + v1=[]) +class Orthogonal(init_ops_v2.Orthogonal, Initializer): + """Initializer that generates an orthogonal matrix. + + Also available via the shortcut function `tf.keras.initializers.orthogonal`. + + If the shape of the tensor to initialize is two-dimensional, it is initialized + with an orthogonal matrix obtained from the QR decomposition of a matrix of + random numbers drawn from a normal distribution. + If the matrix has fewer rows than columns then the output will have orthogonal + rows. Otherwise, the output will have orthogonal columns. + + If the shape of the tensor to initialize is more than two-dimensional, + a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])` + is initialized, where `n` is the length of the shape vector. + The matrix is subsequently reshaped to give a tensor of the desired shape. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.Orthogonal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.Orthogonal() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + gain: multiplicative factor to apply to the orthogonal matrix + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C) + ([pdf](https://arxiv.org/pdf/1312.6120.pdf)) + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized to an orthogonal matrix. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`) + """ + return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.Identity', + 'keras.initializers.identity', + v1=[]) +class Identity(init_ops_v2.Identity, Initializer): + """Initializer that generates the identity matrix. + + Also available via the shortcut function `tf.keras.initializers.identity`. + + Only usable for generating 2D matrices. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.Identity() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.Identity() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + gain: Multiplicative factor to apply to the identity matrix. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized to a 2D identity matrix. + + Args: + shape: Shape of the tensor. It should have exactly rank 2. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `tf.keras.backend.floatx()` is used, + which default to `float32` unless you configured it otherwise + (via `tf.keras.backend.set_floatx(float_dtype)`) + """ + return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype)) + + +@keras_export('keras.initializers.GlorotUniform', + 'keras.initializers.glorot_uniform', + v1=[]) +class GlorotUniform(VarianceScaling): + """The Glorot uniform initializer, also called Xavier uniform initializer. + + Also available via the shortcut function + `tf.keras.initializers.glorot_uniform`. + + Draws samples from a uniform distribution within [-limit, limit] where `limit` + is `sqrt(6 / (fan_in + fan_out))` where `fan_in` is the number of input units + in the weight tensor and `fan_out` is the number of output units in the weight + tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.GlorotUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.GlorotUniform() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) + """ + + def __init__(self, seed=None): + super(GlorotUniform, self).__init__( + scale=1.0, + mode='fan_avg', + distribution='uniform', + seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export('keras.initializers.GlorotNormal', + 'keras.initializers.glorot_normal', + v1=[]) +class GlorotNormal(VarianceScaling): + """The Glorot normal initializer, also called Xavier normal initializer. + + Also available via the shortcut function + `tf.keras.initializers.glorot_normal`. + + Draws samples from a truncated normal distribution centered on 0 with `stddev + = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of input units in + the weight tensor and `fan_out` is the number of output units in the weight + tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.GlorotNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.GlorotNormal() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) + """ + + def __init__(self, seed=None): + super(GlorotNormal, self).__init__( + scale=1.0, + mode='fan_avg', + distribution='truncated_normal', + seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export('keras.initializers.LecunNormal', + 'keras.initializers.lecun_normal', + v1=[]) +class LecunNormal(VarianceScaling): + """Lecun normal initializer. + + Also available via the shortcut function + `tf.keras.initializers.lecun_normal`. + + Initializers allow you to pre-specify an initialization strategy, encoded in + the Initializer object, without knowing the shape and dtype of the variable + being initialized. + + Draws samples from a truncated normal distribution centered on 0 with `stddev + = sqrt(1 / fan_in)` where `fan_in` is the number of input units in the weight + tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.LecunNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.LecunNormal() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Arguments: + seed: A Python integer. Used to seed the random generator. + + References: + - Self-Normalizing Neural Networks, + [Klambauer et al., 2017] + (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) + ([pdf] + (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) + - Efficient Backprop, + [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) + """ + + def __init__(self, seed=None): + super(LecunNormal, self).__init__( + scale=1., mode='fan_in', distribution='truncated_normal', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export('keras.initializers.LecunUniform', + 'keras.initializers.lecun_uniform', + v1=[]) +class LecunUniform(VarianceScaling): + """Lecun uniform initializer. + + Also available via the shortcut function + `tf.keras.initializers.lecun_uniform`. + + Draws samples from a uniform distribution within [-limit, limit] where `limit` + is `sqrt(3 / fan_in)` where `fan_in` is the number of input units in the + weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.LecunUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.LecunUniform() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Arguments: + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + - Self-Normalizing Neural Networks, + [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) # pylint: disable=line-too-long + ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) + - Efficient Backprop, + [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) + """ + + def __init__(self, seed=None): + super(LecunUniform, self).__init__( + scale=1., mode='fan_in', distribution='uniform', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export('keras.initializers.HeNormal', + 'keras.initializers.he_normal', + v1=[]) +class HeNormal(VarianceScaling): + """He normal initializer. + + Also available via the shortcut function + `tf.keras.initializers.he_normal`. + + It draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in the + weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.HeNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.HeNormal() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Arguments: + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long + ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) + """ + + def __init__(self, seed=None): + super(HeNormal, self).__init__( + scale=2., mode='fan_in', distribution='truncated_normal', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +@keras_export('keras.initializers.HeUniform', + 'keras.initializers.he_uniform', + v1=[]) +class HeUniform(VarianceScaling): + """He uniform variance scaling initializer. + + Also available via the shortcut function + `tf.keras.initializers.he_uniform`. + + Draws samples from a uniform distribution within [-limit, limit] where `limit` + is `sqrt(6 / fan_in)` where `fan_in` is the number of input units in the + weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = tf.keras.initializers.HeUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = tf.keras.initializers.HeUniform() + >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) + + Arguments: + seed: A Python integer. An initializer created with a given seed will + always produce the same random tensor for a given shape and dtype. + + References: + [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long + ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) + """ + + def __init__(self, seed=None): + super(HeUniform, self).__init__( + scale=2., mode='fan_in', distribution='uniform', seed=seed) + + def get_config(self): + return {'seed': self.seed} + + +def _get_dtype(dtype): + if dtype is None: + dtype = backend.floatx() + return dtypes.as_dtype(dtype) diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py index fd8f50850a9..3e4502f14fc 100644 --- a/tensorflow/python/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python import tf2 from tensorflow.python.framework import test_util from tensorflow.python.keras import backend from tensorflow.python.keras import combinations @@ -92,7 +91,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(1. / fan_in) self._runner( - initializers.lecun_uniformV2(seed=123), + initializers.LecunUniformV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -114,7 +113,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(2. / fan_in) self._runner( - initializers.he_uniformV2(seed=123), + initializers.HeUniformV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -125,7 +124,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(1. / fan_in) self._runner( - initializers.lecun_normalV2(seed=123), + initializers.LecunNormalV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -147,7 +146,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(2. / fan_in) self._runner( - initializers.he_normalV2(seed=123), + initializers.HeNormalV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -202,15 +201,6 @@ class KerasInitializersTest(test.TestCase): self.assertEqual(tn.mean, 0.0) self.assertEqual(tn.stddev, 0.05) - def test_initializer_v2_get(self): - tf2_force_enabled = tf2._force_enable # pylint: disable=protected-access - try: - tf2.enable() - rn = initializers.get('random_normal') - self.assertIn('init_ops_v2', rn.__class__.__module__) - finally: - tf2._force_enable = tf2_force_enabled # pylint: disable=protected-access - def test_custom_initializer_saving(self): def my_initializer(shape, dtype=None): diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index ad0cdb20f44..15a6f6bd191 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -665,7 +665,10 @@ tf_py_test( srcs = ["gru_test.py"], python_version = "PY3", shard_count = 4, - tags = ["notsan"], # http://b/62136390 + tags = [ + "no_rocm", + "notsan", # http://b/62136390 + ], deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -682,6 +685,7 @@ tf_py_test( python_version = "PY3", shard_count = 4, tags = [ + "no_rocm", "noasan", # times out b/63678675 "notsan", # http://b/62189182 ], diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 9fd902d70e9..9b4bc46ef31 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -242,9 +242,28 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper # Serialization functions +from tensorflow.python.keras.layers import serialization from tensorflow.python.keras.layers.serialization import deserialize from tensorflow.python.keras.layers.serialization import serialize + +class VersionAwareLayers(object): + """Utility to be used internally to access layers in a V1/V2-aware fashion. + + When using layers within the Keras codebase, under the constraint that + e.g. `layers.BatchNormalization` should be the `BatchNormalization` version + corresponding to the current runtime (TF1 or TF2), do not simply access + `layers.BatchNormalization` since it would ignore e.g. an early + `compat.v2.disable_v2_behavior()` call. Instead, use an instance + of `VersionAwareLayers` (which you can use just like the `layers` module). + """ + + def __getattr__(self, name): + serialization.populate_deserializable_objects() + if name in serialization.LOCAL.ALL_OBJECTS: + return serialization.LOCAL.ALL_OBJECTS[name] + return super(VersionAwareLayers, self).__getattr__(name) + del absolute_import del division del print_function diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 571eaeab0b2..15e8f5e526d 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -1158,6 +1158,7 @@ class Conv3DTranspose(Conv3D): padding='valid', output_padding=None, data_format=None, + dilation_rate=(1, 1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform', @@ -1174,6 +1175,7 @@ class Conv3DTranspose(Conv3D): strides=strides, padding=padding, data_format=data_format, + dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, kernel_initializer=initializers.get(kernel_initializer), diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index 48f724b55e1..1929b145561 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer): def bias_initializer(_, *args, **kwargs): return K.concatenate([ self.bias_initializer((self.filters,), *args, **kwargs), - initializers.Ones()((self.filters,), *args, **kwargs), + initializers.get('ones')((self.filters,), *args, **kwargs), self.bias_initializer((self.filters * 2,), *args, **kwargs), ]) else: diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index 6dd63802c91..41a69be761b 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -276,6 +276,39 @@ class Conv3DTest(keras_parameterized.TestCase): input_data=input_data) +@keras_parameterized.run_all_keras_modes +class Conv3DTransposeTest(keras_parameterized.TestCase): + + def _run_test(self, kwargs, expected_output_shape): + num_samples = 2 + stack_size = 3 + num_row = 7 + num_col = 6 + depth = 5 + + with test_util.use_gpu(): + testing_utils.layer_test( + keras.layers.Conv3DTranspose, + kwargs=kwargs, + input_shape=(num_samples, depth, num_row, num_col, stack_size), + expected_output_shape=expected_output_shape) + + @parameterized.named_parameters( + ('padding_valid', {'padding': 'valid'}, (None, 7, 9, 8, 2)), + ('padding_same', {'padding': 'same'}, (None, 5, 7, 6, 2)), + ('strides', {'strides': (2, 2, 2)}, (None, 11, 15, 13, 2)), + ('dilation_rate', {'dilation_rate': (2, 2, 2)}, (None, 7, 9, 8, 2)), + # Only runs on GPU with CUDA, channels_first is not supported on CPU. + # TODO(b/62340061): Support channels_first on CPU. + ('data_format', {'data_format': 'channels_first'}), + ) + def test_conv3d_transpose(self, kwargs, expected_output_shape=None): + kwargs['filters'] = 2 + kwargs['kernel_size'] = (3, 3, 3) + if 'data_format' not in kwargs or test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, expected_output_shape) + + @keras_parameterized.run_all_keras_modes class ConvSequentialTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 1f7a4cba1f8..21711116757 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -61,7 +61,7 @@ class Embedding(Layer): Arguments: input_dim: int > 0. Size of the vocabulary, i.e. maximum integer index + 1. - output_dim: int >= 0. Dimension of the dense embedding. + output_dim: int > 0. Dimension of the dense embedding. embeddings_initializer: Initializer for the `embeddings` matrix. embeddings_regularizer: Regularizer function applied to the `embeddings` matrix. @@ -103,6 +103,10 @@ class Embedding(Layer): kwargs['input_shape'] = (input_length,) else: kwargs['input_shape'] = (None,) + if input_dim <= 0 or output_dim <= 0: + raise ValueError('Both `input_dim` and `output_dim` should be positive, ' + 'found input_dim {} and output_dim {}'.format( + input_dim, output_dim)) dtype = kwargs.pop('dtype', K.floatx()) # We set autocast to False, as we do not want to cast floating- point inputs # to self.dtype. In call(), we cast to int32, and casting to self.dtype diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py index ee067e786b0..661b29cd7bf 100644 --- a/tensorflow/python/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/layers/embeddings_test.py @@ -86,6 +86,13 @@ class EmbeddingTest(keras_parameterized.TestCase): outputs = model.predict(np.array([[0, 1, 0]], dtype='int32')) self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]]) + def test_embedding_incorrect_dimension(self): + with self.assertRaises(ValueError): + keras.layers.Embedding(input_dim=0, output_dim=1) + + with self.assertRaises(ValueError): + keras.layers.Embedding(input_dim=1, output_dim=0) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_eager_gpu_cpu(self): l = keras.layers.Embedding(output_dim=2, input_dim=2) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD new file mode 100644 index 00000000000..f488e1da34f --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -0,0 +1,28 @@ +# Benchmarks for Keras preprocessing layers. +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +tf_py_test( + name = "index_lookup_adapt_benchmark", + srcs = ["index_lookup_adapt_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:index_lookup", + ], +) + +tf_py_test( + name = "normalization_adapt_benchmark", + srcs = ["normalization_adapt_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:normalization", + ], +) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py new file mode 100644 index 00000000000..619fb86103b --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py @@ -0,0 +1,124 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark for Keras text vectorization preprocessing layer's adapt method.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import random +import string +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.layers.preprocessing import index_lookup +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + + +# word_gen creates random sequences of ASCII letters (both lowercase and upper). +# The number of unique strings is ~2,700. +def word_gen(): + for _ in itertools.count(1): + yield "".join(random.choice(string.ascii_letters) for i in range(2)) + + +def get_top_k(dataset, k): + """Python implementation of vocabulary building using a defaultdict.""" + counts = collections.defaultdict(int) + for tensor in dataset: + data = tensor.numpy() + for element in data: + counts[element] += 1 + sorted_vocab = [ + k for k, _ in sorted( + counts.items(), key=lambda item: item[1], reverse=True) + ] + if len(sorted_vocab) > k: + sorted_vocab = sorted_vocab[:k] + return sorted_vocab + + +class BenchmarkAdapt(benchmark.Benchmark): + """Benchmark adapt.""" + + def run_numpy_implementation(self, num_elements, batch_size, k): + """Test the python implementation.""" + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + batched_ds = ds.take(num_elements).batch(batch_size) + input_t = keras.Input(shape=(), dtype=dtypes.string) + layer = index_lookup.IndexLookup( + max_tokens=k, num_oov_tokens=0, reserve_zero=False) + _ = layer(input_t) + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + starts.append(time.time()) + vocab = get_top_k(batched_ds, k) + layer.set_vocabulary(vocab) + ends.append(time.time()) + avg_time = np.mean(np.array(ends) - np.array(starts)) + return avg_time + + def bm_adapt_implementation(self, num_elements, batch_size, k): + """Test the KPL adapt implementation.""" + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + batched_ds = ds.take(num_elements).batch(batch_size) + input_t = keras.Input(shape=(), dtype=dtypes.string) + layer = index_lookup.IndexLookup( + max_tokens=k, num_oov_tokens=0, reserve_zero=False) + _ = layer(input_t) + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + starts.append(time.time()) + layer.adapt(batched_ds) + ends.append(time.time()) + avg_time = np.mean(np.array(ends) - np.array(starts)) + name = "index_lookup_adapt|%s_elements|vocab_size_%s|batch_%s" % ( + num_elements, k, batch_size) + baseline = self.run_numpy_implementation(num_elements, batch_size, k) + extras = { + "numpy implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for vocab_size in [100, 1000, 10000, 100000, 1000000]: + for batch in [1, 16, 2048]: + self.bm_adapt_implementation(vocab_size, batch, int(vocab_size / 10)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py new file mode 100644 index 00000000000..dfce2963f75 --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py @@ -0,0 +1,133 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark for Keras text vectorization preprocessing layer's adapt method.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.layers.preprocessing import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + + +def reduce_fn(state, values): + """tf.data.Dataset-friendly implementation of mean and variance.""" + k, n, ex, ex2 = state + # If this is the first iteration, we pick the first value to be 'k', + # which helps with precision - we assume that k is close to an average + # value and calculate mean and variance with respect to that. + k = control_flow_ops.cond(math_ops.equal(n, 0), lambda: values[0], lambda: k) + + sum_v = math_ops.reduce_sum(values, axis=0) + sum_v2 = math_ops.reduce_sum(math_ops.square(values), axis=0) + ones = array_ops.ones_like(values, dtype=dtypes.int32) + batch_size = math_ops.reduce_sum(ones, axis=0) + batch_size_f = math_ops.cast(batch_size, dtypes.float32) + + ex = 0 + sum_v - math_ops.multiply(batch_size_f, k) + ex2 = 0 + sum_v2 + math_ops.multiply( + batch_size_f, (math_ops.square(k) - + math_ops.multiply(math_ops.multiply(2.0, k), sum_v))) + + return (k, n + batch_size, ex, ex2) + + +class BenchmarkAdapt(benchmark.Benchmark): + """Benchmark adapt.""" + + def run_dataset_implementation(self, num_elements, batch_size): + input_t = keras.Input(shape=(1,)) + layer = normalization.Normalization() + _ = layer(input_t) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.range(num_elements) + ds = ds.map( + lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1)) + ds = ds.batch(batch_size) + + starts.append(time.time()) + # Benchmarked code begins here. + k, n, ex, ex2 = ds.reduce((0.0, 0, 0.0, 0.0), reduce_fn) + mean = k.numpy() + ex.numpy() / n.numpy() + var = (ex2.numpy() - (ex.numpy() * ex.numpy()) / n.numpy()) / ( + n.numpy() - 1) + layer.set_weights([mean, var]) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) + return avg_time + + def bm_adapt_implementation(self, num_elements, batch_size): + """Test the KPL adapt implementation.""" + input_t = keras.Input(shape=(1,), dtype=dtypes.float32) + layer = normalization.Normalization() + _ = layer(input_t) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.range(num_elements) + ds = ds.map( + lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1)) + ds = ds.batch(batch_size) + + starts.append(time.time()) + # Benchmarked code begins here. + layer.adapt(ds) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) + name = "normalization_adapt|%s_elements|batch_%s" % (num_elements, + batch_size) + baseline = self.run_dataset_implementation(num_elements, batch_size) + extras = { + "tf.data implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for vocab_size in [100, 1000, 10000, 100000, 1000000]: + for batch in [1, 16, 2048]: + self.bm_adapt_implementation(vocab_size, batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/discretization.py b/tensorflow/python/keras/layers/preprocessing/discretization.py index 1e80c5621b6..3427a311078 100644 --- a/tensorflow/python/keras/layers/preprocessing/discretization.py +++ b/tensorflow/python/keras/layers/preprocessing/discretization.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras.engine.base_layer import Layer @@ -77,6 +78,9 @@ class Discretization(Layer): def compute_output_signature(self, input_spec): output_shape = self.compute_output_shape(input_spec.shape.as_list()) output_dtype = dtypes.int64 + if isinstance(input_spec, sparse_tensor.SparseTensorSpec): + return sparse_tensor.SparseTensorSpec( + shape=output_shape, dtype=output_dtype) return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) def call(self, inputs): @@ -87,12 +91,24 @@ class Discretization(Layer): # ragged composite tensor. If this op is the only op a Keras model, # this can cause errors in Graph mode, so wrap the tensor in an identity. integer_buckets = array_ops.identity(integer_buckets) + elif isinstance(inputs, sparse_tensor.SparseTensor): + integer_buckets = math_ops._bucketize( # pylint: disable=protected-access + inputs.values, + boundaries=self.bins) else: integer_buckets = math_ops._bucketize(inputs, boundaries=self.bins) # pylint: disable=protected-access if self.output_mode == INTEGER: + if isinstance(inputs, sparse_tensor.SparseTensor): + return sparse_tensor.SparseTensor( + indices=array_ops.identity(inputs.indices), + values=integer_buckets, + dense_shape=array_ops.identity(inputs.dense_shape)) return integer_buckets else: + if isinstance(inputs, sparse_tensor.SparseTensor): + raise ValueError("`output_mode=binary` is not supported for " + "sparse input") # The 'bins' array is the set of boundaries between the bins. We actually # have 'len(bins)+1' outputs. # TODO(momernick): This will change when we have the ability to adapt(). diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_test.py index c847fe73d70..110bccd55e1 100644 --- a/tensorflow/python/keras/layers/preprocessing/discretization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/discretization_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.layers.preprocessing import discretization from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils @@ -87,6 +88,21 @@ class CategoricalEncodingInputTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_bucketize_with_explicit_buckets_sparse_float_input(self): + indices = [[0, 1], [0, 2], [1, 1]] + input_array = sparse_tensor.SparseTensor( + indices=indices, values=[-1.5, 1.0, 3.4], dense_shape=[2, 3]) + expected_output = [0, 2, 3] + input_data = keras.Input(shape=(3,), dtype=dtypes.float32, sparse=True) + layer = discretization.Discretization( + bins=[-.5, 0.5, 1.5], output_mode=discretization.INTEGER) + bucket_data = layer(input_data) + + model = keras.Model(inputs=input_data, outputs=bucket_data) + output_dataset = model.predict(input_array, steps=1) + self.assertAllEqual(indices, output_dataset.indices) + self.assertAllEqual(expected_output, output_dataset.values) + def test_bucketize_with_explicit_buckets_ragged_float_input(self): input_array = ragged_factory_ops.constant([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3]]) @@ -121,6 +137,21 @@ class CategoricalEncodingInputTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_bucketize_with_explicit_buckets_sparse_int_input(self): + indices = [[0, 1], [0, 2], [1, 1]] + input_array = sparse_tensor.SparseTensor( + indices=indices, values=[-1, 1, 3], dense_shape=[2, 3]) + expected_output = [0, 2, 3] + input_data = keras.Input(shape=(3,), dtype=dtypes.int32, sparse=True) + layer = discretization.Discretization( + bins=[-.5, 0.5, 1.5], output_mode=discretization.INTEGER) + bucket_data = layer(input_data) + + model = keras.Model(inputs=input_data, outputs=bucket_data) + output_dataset = model.predict(input_array, steps=1) + self.assertAllEqual(indices, output_dataset.indices) + self.assertAllEqual(expected_output, output_dataset.values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py index 364b2e3fe25..bc5fab1604d 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py @@ -341,7 +341,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): start_index = self._reserved_values + (self.vocab_size() if append else 0) values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64) - vocab = self._convert_to_ndarray(vocab, self.dtype) self._assert_same_type(self.dtype, vocab, "vocab") @@ -459,8 +458,11 @@ class _IndexLookupCombiner(base_preprocessing_layer.Combiner): # TODO(momernick): Benchmark improvements to this algorithm. for document in values: - for token in document: - accumulator.count_dict[token] += 1 + if not isinstance(document, list): + accumulator.count_dict[document] += 1 + else: + for token in document: + accumulator.count_dict[token] += 1 return accumulator diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py index de8d5623f5e..d9990ddb037 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools import os +import random +import string from absl.testing import parameterized import numpy as np @@ -30,6 +33,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers.preprocessing import index_lookup @@ -383,6 +387,21 @@ class CategoricalEncodingAdaptTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_single_string_generator_dataset(self): + + def word_gen(): + for _ in itertools.count(1): + yield "".join(random.choice(string.ascii_letters) for i in range(2)) + + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + batched_ds = ds.take(100).batch(1) + input_t = keras.Input(shape=(), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=10, num_oov_tokens=0, reserve_zero=False) + _ = layer(input_t) + layer.adapt(batched_ds) + @keras_parameterized.run_all_keras_modes class IndexLookupOutputTest(keras_parameterized.TestCase, diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index 00ee2adf70d..5a0b8990486 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import json import numpy as np @@ -53,9 +52,9 @@ class Normalization(CombinerPreprocessingLayer): Attributes: axis: Integer or tuple of integers, the axis or axes that should be normalized (typically the features axis). We will normalize each element - in the specified axis. If set to 'None', the layer will perform - scalar normalization (diving the input by a single scalar value). - 0 (the batch axis) is not allowed. + in the specified axis. If set to 'None', the layer will perform scalar + normalization (diving the input by a single scalar value). 0 (the batch + axis) is not allowed. """ def __init__(self, axis=-1, dtype=None, **kwargs): @@ -132,12 +131,6 @@ class Normalization(CombinerPreprocessingLayer): super(Normalization, self).set_weights(weights) -class _NormalizingAccumulator( - collections.namedtuple('_NormalizingAccumulator', - ['count', 'mean', 'variance'])): - pass - - class _NormalizingCombiner(Combiner): """Combiner for the Normalization preprocessing layer. @@ -148,6 +141,9 @@ class _NormalizingCombiner(Combiner): Attributes: axis: The axis to compute mean and var over. """ + COUNT_IDX = 0 + MEAN_IDX = 1 + VAR_IDX = 2 def __init__(self, axis): self.axis = axis @@ -171,34 +167,62 @@ class _NormalizingCombiner(Combiner): reduction_axes = None else: reduction_axes = tuple(np.delete(range(values.ndim), self.axis)) + mean = np.mean(values, axis=reduction_axes, dtype=np.float64) variance = np.var(values, axis=reduction_axes, dtype=np.float64) # Create an accumulator with our new data and either return it or combine # it with the passed accumulator. - sanitized_accumulator = self._create_accumulator(count, mean, variance) if accumulator is None: - return sanitized_accumulator + return self._create_accumulator(count, mean, variance) else: - return self.merge([accumulator, sanitized_accumulator]) + return self.add_data_to_accumulator(count, mean, variance, accumulator) + + def add_data_to_accumulator(self, count, mean, variance, accumulator): + """Add new data to the totals in an accumulator.""" + # Combine accumulators and return the result. + combined_count = count + accumulator[self.COUNT_IDX] + + # To combine accumulator means, we weight each accumulator's mean by the + # number of elements that were accumulated, and then divide by the + # total number of elements. + combined_mean = (mean * count + accumulator[self.MEAN_IDX] * + accumulator[self.COUNT_IDX]) / combined_count + + # The variance is computed using the lack-of-fit sum of squares + # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares). + accumulator_var_contribution = accumulator[self.COUNT_IDX] * ( + accumulator[self.VAR_IDX] + + np.square(accumulator[self.MEAN_IDX] - combined_mean)) + data_var_contribution = count * (variance + np.square(mean - combined_mean)) + combined_variance = (accumulator_var_contribution + + data_var_contribution) / combined_count + + accumulator[self.COUNT_IDX] = combined_count + accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean) + accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance) + return accumulator def merge(self, accumulators): """Merge several accumulators to a single accumulator.""" # Combine accumulators and return the result. - combined_count = np.sum([accumulator.count for accumulator in accumulators]) + combined_count = np.sum( + [accumulator[self.COUNT_IDX] for accumulator in accumulators]) # To combine accumulator means, we weight each accumulator's mean by the # number of elements that were accumulated, and then divide by the # total number of elements. combined_mean = np.add.reduce([ - accumulator.mean * accumulator.count for accumulator in accumulators + accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX] + for accumulator in accumulators ]) / combined_count # The variance is computed using the lack-of-fit sum of squares # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares). def variance_contribution(accumulator): - return accumulator.count * ( - accumulator.variance + np.square(accumulator.mean - combined_mean)) + return accumulator[self.COUNT_IDX] * ( + accumulator[self.VAR_IDX] + + np.square(accumulator[self.MEAN_IDX] - combined_mean)) combined_variance = np.add.reduce([ variance_contribution(accumulator) for accumulator in accumulators @@ -210,9 +234,9 @@ class _NormalizingCombiner(Combiner): def extract(self, accumulator): """Convert an accumulator into a dict of output values.""" return { - _COUNT_NAME: accumulator.count, - _MEAN_NAME: accumulator.mean, - _VARIANCE_NAME: accumulator.variance + _COUNT_NAME: accumulator[self.COUNT_IDX], + _MEAN_NAME: accumulator[1], + _VARIANCE_NAME: accumulator[2] } def restore(self, output): @@ -233,9 +257,9 @@ class _NormalizingCombiner(Combiner): def serialize(self, accumulator): """Serialize an accumulator for a remote call.""" output_dict = { - _COUNT_NAME: accumulator.count.tolist(), - _MEAN_NAME: accumulator.mean.tolist(), - _VARIANCE_NAME: accumulator.variance.tolist() + _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(), + _MEAN_NAME: accumulator[1].tolist(), + _VARIANCE_NAME: accumulator[2].tolist() } return compat.as_bytes(json.dumps(output_dict)) @@ -248,5 +272,4 @@ class _NormalizingCombiner(Combiner): def _create_accumulator(self, count, mean, variance): """Convert any 'nan' values in the given accumulator to numeric values.""" - return _NormalizingAccumulator( - np.array(count), np.nan_to_num(mean), np.nan_to_num(variance)) + return [count, mean, variance] diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index ec635590e8b..628ecc332c5 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer): def bias_initializer(_, *args, **kwargs): return K.concatenate([ self.bias_initializer((self.units,), *args, **kwargs), - initializers.Ones()((self.units,), *args, **kwargs), + initializers.get('ones')((self.units,), *args, **kwargs), self.bias_initializer((self.units * 2,), *args, **kwargs), ]) else: diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index afefcc3f040..64bee4d6121 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -21,51 +21,136 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading + from tensorflow.python import tf2 -from tensorflow.python.keras.engine.base_layer import AddLoss -from tensorflow.python.keras.engine.base_layer import AddMetric -from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer -from tensorflow.python.keras.engine.input_layer import Input -from tensorflow.python.keras.engine.input_layer import InputLayer -from tensorflow.python.keras.layers.advanced_activations import * -from tensorflow.python.keras.layers.convolutional import * -from tensorflow.python.keras.layers.convolutional_recurrent import * -from tensorflow.python.keras.layers.core import * -from tensorflow.python.keras.layers.cudnn_recurrent import * -from tensorflow.python.keras.layers.dense_attention import * -from tensorflow.python.keras.layers.embeddings import * -from tensorflow.python.keras.layers.local import * -from tensorflow.python.keras.layers.merge import * -from tensorflow.python.keras.layers.noise import * -from tensorflow.python.keras.layers.normalization import * -from tensorflow.python.keras.layers.pooling import * -from tensorflow.python.keras.layers.preprocessing.image_preprocessing import * -from tensorflow.python.keras.layers.preprocessing.normalization_v1 import * -from tensorflow.python.keras.layers.recurrent import * -from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import * -from tensorflow.python.keras.layers.wrappers import * -from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.engine import input_spec +from tensorflow.python.keras.layers import advanced_activations +from tensorflow.python.keras.layers import convolutional +from tensorflow.python.keras.layers import convolutional_recurrent +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import cudnn_recurrent +from tensorflow.python.keras.layers import dense_attention +from tensorflow.python.keras.layers import embeddings +from tensorflow.python.keras.layers import local +from tensorflow.python.keras.layers import merge +from tensorflow.python.keras.layers import noise +from tensorflow.python.keras.layers import normalization +from tensorflow.python.keras.layers import normalization_v2 +from tensorflow.python.keras.layers import pooling +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.keras.layers import recurrent_v2 +from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 +from tensorflow.python.keras.layers import wrappers +from tensorflow.python.keras.layers.preprocessing import image_preprocessing +from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization +from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1 +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.util import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export -if tf2.enabled(): - from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top -# This deserialization table is added for backward compatibility, as in TF 1.13, -# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1 -# and v2 version of BatchNormalization, respectively. Here we explicitly convert -# them to the canonical name in the config of deserialization. -_DESERIALIZATION_TABLE = { - 'BatchNormalizationV1': 'BatchNormalization', - 'BatchNormalizationV2': 'BatchNormalization', -} +ALL_MODULES = ( + base_layer, + input_layer, + advanced_activations, + convolutional, + convolutional_recurrent, + core, + cudnn_recurrent, + dense_attention, + embeddings, + local, + merge, + noise, + normalization, + pooling, + image_preprocessing, + preprocessing_normalization_v1, + recurrent, + wrappers +) +ALL_V2_MODULES = ( + rnn_cell_wrapper_v2, + normalization_v2, + recurrent_v2, + preprocessing_normalization +) + +# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it +# thread-local to avoid concurrent mutations. +LOCAL = threading.local() + + +def populate_deserializable_objects(): + """Populates dict ALL_OBJECTS with every built-in layer. + """ + global LOCAL + if not hasattr(LOCAL, 'ALL_OBJECTS'): + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = None + + if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled(): + # Objects dict is already generated for the proper TF version: + # do nothing. + return + + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = tf2.enabled() + + base_cls = base_layer.Layer + generic_utils.populate_dict_with_module_objects( + LOCAL.ALL_OBJECTS, + ALL_MODULES, + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + + # Overwrite certain V1 objects with V2 versions + if tf2.enabled(): + generic_utils.populate_dict_with_module_objects( + LOCAL.ALL_OBJECTS, + ALL_V2_MODULES, + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + + # These deserialization aliases are added for backward compatibility, + # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2" + # were used as class name for v1 and v2 version of BatchNormalization, + # respectively. Here we explicitly convert them to their canonical names. + LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization + LOCAL.ALL_OBJECTS[ + 'BatchNormalizationV2'] = normalization_v2.BatchNormalization + + # Prevent circular dependencies. + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top + from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top + from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top + + LOCAL.ALL_OBJECTS['Input'] = input_layer.Input + LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec + LOCAL.ALL_OBJECTS['Network'] = models.Network + LOCAL.ALL_OBJECTS['Model'] = models.Model + LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential + LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel + LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel + LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures + LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures + # Merge layers, function versions. + LOCAL.ALL_OBJECTS['add'] = merge.add + LOCAL.ALL_OBJECTS['subtract'] = merge.subtract + LOCAL.ALL_OBJECTS['multiply'] = merge.multiply + LOCAL.ALL_OBJECTS['average'] = merge.average + LOCAL.ALL_OBJECTS['maximum'] = merge.maximum + LOCAL.ALL_OBJECTS['minimum'] = merge.minimum + LOCAL.ALL_OBJECTS['concatenate'] = merge.concatenate + LOCAL.ALL_OBJECTS['dot'] = merge.dot @keras_export('keras.layers.serialize') def serialize(layer): - return serialize_keras_object(layer) + return generic_utils.serialize_keras_object(layer) @keras_export('keras.layers.deserialize') @@ -80,30 +165,9 @@ def deserialize(config, custom_objects=None): Returns: Layer instance (may be Model, Sequential, Network, Layer...) """ - # Prevent circular dependencies. - from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top - from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top - from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top - - globs = globals() # All layers. - globs['Network'] = models.Network - globs['Model'] = models.Model - globs['Sequential'] = models.Sequential - globs['LinearModel'] = LinearModel - globs['WideDeepModel'] = WideDeepModel - - # Prevent circular dependencies with FeatureColumn serialization. - globs['DenseFeatures'] = dense_features.DenseFeatures - globs['SequenceFeatures'] = sfc.SequenceFeatures - - layer_class_name = config['class_name'] - if layer_class_name in _DESERIALIZATION_TABLE: - config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name] - - return deserialize_keras_object( + populate_deserializable_objects() + return generic_utils.deserialize_keras_object( config, - module_objects=globs, + module_objects=LOCAL.ALL_OBJECTS, custom_objects=custom_objects, printable_module_name='layer') diff --git a/tensorflow/python/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py index 5c23937ddb4..cd88b072224 100644 --- a/tensorflow/python/keras/layers/serialization_test.py +++ b/tensorflow/python/keras/layers/serialization_test.py @@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase): layer = batchnorm_layer( momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2') config = keras.layers.serialize(layer) - # To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the - # saved model. - if batchnorm_layer is batchnorm_v1.BatchNormalization: - config['class_name'] = 'BatchNormalizationV1' - else: - config['class_name'] = 'BatchNormalizationV2' new_layer = keras.layers.deserialize(config) self.assertEqual(new_layer.momentum, 0.9) if tf2.enabled(): diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 97b51501b18..20f87538e9c 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -67,12 +67,7 @@ class Wrapper(Layer): return None def get_config(self): - config = { - 'layer': { - 'class_name': self.layer.__class__.__name__, - 'config': self.layer.get_config() - } - } + config = {'layer': generic_utils.serialize_keras_object(self.layer)} base_config = super(Wrapper, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -80,7 +75,7 @@ class Wrapper(Layer): def from_config(cls, config, custom_objects=None): from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top # Avoid mutating the input dict - config = config.copy() + config = copy.deepcopy(config) layer = deserialize_layer( config.pop('layer'), custom_objects=custom_objects) return cls(layer, **config) @@ -426,7 +421,8 @@ class Bidirectional(Wrapper): # Keep the custom backward layer config, so that we can save it later. The # layer's name might be updated below with prefix 'backward_', and we want # to preserve the original config. - self._backward_layer_config = backward_layer.get_config() + self._backward_layer_config = generic_utils.serialize_keras_object( + backward_layer) self.forward_layer._name = 'forward_' + self.forward_layer.name self.backward_layer._name = 'backward_' + self.backward_layer.name @@ -720,26 +716,26 @@ class Bidirectional(Wrapper): config['num_constants'] = self._num_constants if hasattr(self, '_backward_layer_config'): - config['backward_layer'] = { - 'class_name': self.backward_layer.__class__.__name__, - 'config': self._backward_layer_config, - } + config['backward_layer'] = self._backward_layer_config base_config = super(Bidirectional, self).get_config() return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config, custom_objects=None): # Instead of updating the input, create a copy and use that. - config = config.copy() + config = copy.deepcopy(config) num_constants = config.pop('num_constants', 0) + # Handle forward layer instantiation (as would parent class). + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + config['layer'] = deserialize_layer( + config['layer'], custom_objects=custom_objects) + # Handle (optional) backward layer instantiation. backward_layer_config = config.pop('backward_layer', None) if backward_layer_config is not None: - from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top backward_layer = deserialize_layer( backward_layer_config, custom_objects=custom_objects) config['backward_layer'] = backward_layer - - layer = super(Bidirectional, cls).from_config(config, - custom_objects=custom_objects) + # Instantiate the wrapper, adjust it and return it. + layer = cls(**config) layer._num_constants = num_constants return layer diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 8765cf2dc25..bba538f0f09 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -900,6 +900,12 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): [None, 2, 16]) def test_Bidirectional_last_output_with_masking(self): + if test.is_built_with_rocm(): + # testcase uses input and/or output sequences which require padding + # leading to the following error on ROCm platform + # ROCm MIOpen only supports packed input output + # Skip this subtest for now + self.skipTest('Test not supported on the ROCm platform') rnn = keras.layers.LSTM samples = 2 dim = 5 @@ -927,6 +933,12 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): @parameterized.parameters([keras.layers.LSTM, keras.layers.GRU]) def test_Bidirectional_sequence_output_with_masking(self, rnn): + if test.is_built_with_rocm(): + # testcase uses input and/or output sequences which require padding + # leading to the following error on ROCm platform + # ROCm MIOpen only supports packed input output + # Skip this subtest for now + self.skipTest('Test not supported on the ROCm platform') samples = 2 dim = 5 timesteps = 3 @@ -1128,6 +1140,9 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): @parameterized.parameters(['ave', 'concat', 'mul']) def test_Bidirectional_ragged_input(self, merge_mode): + if test.is_built_with_rocm(): + # ragged tenors are not supported in ROCM RNN implementation + self.skipTest('Test not supported on the ROCm platform') np.random.seed(100) rnn = keras.layers.LSTM units = 3 diff --git a/tensorflow/python/keras/legacy_tf_layers/BUILD b/tensorflow/python/keras/legacy_tf_layers/BUILD index 49907d4328a..3ce0d3c6bec 100644 --- a/tensorflow/python/keras/legacy_tf_layers/BUILD +++ b/tensorflow/python/keras/legacy_tf_layers/BUILD @@ -159,6 +159,7 @@ tf_py_test( srcs = ["pooling_test.py"], main = "pooling_test.py", python_version = "PY3", + tags = ["no_rocm"], deps = [ ":pooling", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 5e9f49faf31..9db3435fa50 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -1014,7 +1014,7 @@ class LogCosh(LossFunctionWrapper): ``` """ - def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='logcosh'): + def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'): """Initializes `LogCosh` instance. Args: @@ -1027,9 +1027,9 @@ class LogCosh(LossFunctionWrapper): will raise an error. Please see this custom training [tutorial] (https://www.tensorflow.org/tutorials/distribute/custom_training) for more details. - name: Optional name for the op. Defaults to 'logcosh'. + name: Optional name for the op. Defaults to 'log_cosh'. """ - super(LogCosh, self).__init__(logcosh, name=name, reduction=reduction) + super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction) @keras_export('keras.losses.KLDivergence') @@ -1075,7 +1075,7 @@ class KLDivergence(LossFunctionWrapper): def __init__(self, reduction=losses_utils.ReductionV2.AUTO, - name='kullback_leibler_divergence'): + name='kl_divergence'): """Initializes `KLDivergence` instance. Args: @@ -1088,10 +1088,10 @@ class KLDivergence(LossFunctionWrapper): will raise an error. Please see this custom training [tutorial] (https://www.tensorflow.org/tutorials/distribute/custom_training) for more details. - name: Optional name for the op. Defaults to 'kullback_leibler_divergence'. + name: Optional name for the op. Defaults to 'kl_divergence'. """ super(KLDivergence, self).__init__( - kullback_leibler_divergence, name=name, reduction=reduction) + kl_divergence, name=name, reduction=reduction) @keras_export('keras.losses.Huber') @@ -1160,7 +1160,7 @@ class Huber(LossFunctionWrapper): name: Optional name for the op. Defaults to 'huber_loss'. """ super(Huber, self).__init__( - huber_loss, name=name, reduction=reduction, delta=delta) + huber, name=name, reduction=reduction, delta=delta) @keras_export('keras.metrics.mean_squared_error', @@ -1401,8 +1401,7 @@ def categorical_hinge(y_true, y_pred): >>> assert np.array_equal(loss.numpy(), np.maximum(0., neg - pos + 1.)) Args: - y_true: The ground truth values. `y_true` values are expected to be -1 or 1. - If binary (0 or 1) labels are provided they will be converted to -1 or 1. + y_true: The ground truth values. `y_true` values are expected to be 0 or 1. y_pred: The predicted values. Returns: @@ -1415,7 +1414,8 @@ def categorical_hinge(y_true, y_pred): return math_ops.maximum(0., neg - pos + 1.) -def huber_loss(y_true, y_pred, delta=1.0): +@keras_export('keras.losses.huber', v1=[]) +def huber(y_true, y_pred, delta=1.0): """Computes Huber loss value. For each value x in `error = y_true - y_pred`: @@ -1450,8 +1450,8 @@ def huber_loss(y_true, y_pred, delta=1.0): axis=-1) -@keras_export('keras.losses.logcosh') -def logcosh(y_true, y_pred): +@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh') +def log_cosh(y_true, y_pred): """Logarithm of the hyperbolic cosine of the prediction error. `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and @@ -1595,13 +1595,15 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1) -@keras_export('keras.metrics.kullback_leibler_divergence', +@keras_export('keras.metrics.kl_divergence', + 'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld', 'keras.metrics.KLD', + 'keras.losses.kl_divergence', 'keras.losses.kullback_leibler_divergence', 'keras.losses.kld', 'keras.losses.KLD') -def kullback_leibler_divergence(y_true, y_pred): +def kl_divergence(y_true, y_pred): """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. `loss = y_true * log(y_true / y_pred)` @@ -1796,7 +1798,9 @@ mse = MSE = mean_squared_error mae = MAE = mean_absolute_error mape = MAPE = mean_absolute_percentage_error msle = MSLE = mean_squared_logarithmic_error -kld = KLD = kullback_leibler_divergence +kld = KLD = kullback_leibler_divergence = kl_divergence +logcosh = log_cosh +huber_loss = huber def is_categorical_crossentropy(loss): diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 855a1ed41a3..119cc5db87d 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -36,8 +36,8 @@ ALL_LOSSES = [ losses.mean_absolute_percentage_error, losses.mean_squared_logarithmic_error, losses.squared_hinge, losses.hinge, losses.categorical_crossentropy, losses.binary_crossentropy, - losses.kullback_leibler_divergence, losses.poisson, - losses.cosine_similarity, losses.logcosh, losses.categorical_hinge + losses.kl_divergence, losses.poisson, + losses.cosine_similarity, losses.log_cosh, losses.categorical_hinge ] diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 4333ff784f8..5cbd59c49cf 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=unused-import +# pylint: disable=g-classes-have-attributes """Built-in metrics. """ from __future__ import absolute_import @@ -77,6 +78,11 @@ from tensorflow.tools.docs import doc_controls class Metric(base_layer.Layer): """Encapsulates metric logic and state. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + **kwargs: Additional layer keywords arguments. + Usage: ```python @@ -291,16 +297,15 @@ class Metric(base_layer.Layer): class Reduce(Metric): - """Encapsulates metrics that perform a reduce operation on the values.""" + """Encapsulates metrics that perform a reduce operation on the values. + + Args: + reduction: a `tf.keras.metrics.Reduction` enum value. + name: string name of the metric instance. + dtype: (Optional) data type of the metric result. + """ def __init__(self, reduction, name, dtype=None): - """Creates a `Reduce` instance. - - Args: - reduction: a `tf.keras.metrics.Reduction` enum value. - name: string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(Reduce, self).__init__(name=name, dtype=dtype) self.reduction = reduction with ops.init_scope(): @@ -312,11 +317,7 @@ class Reduce(Metric): 'count', initializer=init_ops.zeros_initializer) def update_state(self, values, sample_weight=None): - """Accumulates statistics for computing the reduction metric. - - For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, - then the value of `result()` is 4. If the `sample_weight` is specified as - [1, 1, 0, 0] then value of `result()` would be 2. + """Accumulates statistics for computing the metric. Args: values: Per-example value. @@ -399,6 +400,10 @@ class Sum(Reduce): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Sum() @@ -416,12 +421,6 @@ class Sum(Reduce): """ def __init__(self, name='sum', dtype=None): - """Creates a `Sum` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM, name=name, dtype=dtype) @@ -440,6 +439,10 @@ class Mean(Reduce): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Mean() @@ -461,12 +464,6 @@ class Mean(Reduce): """ def __init__(self, name='mean', dtype=None): - """Creates a `Mean` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(Mean, self).__init__( reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype) @@ -483,6 +480,11 @@ class MeanRelativeError(Mean): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + normalizer: The normalizer values with same shape as predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3]) @@ -506,13 +508,6 @@ class MeanRelativeError(Mean): """ def __init__(self, normalizer, name=None, dtype=None): - """Creates a `MeanRelativeError` instance. - - Args: - normalizer: The normalizer values with same shape as predictions. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(MeanRelativeError, self).__init__(name=name, dtype=dtype) normalizer = math_ops.cast(normalizer, self._dtype) self.normalizer = normalizer @@ -555,18 +550,17 @@ class MeanRelativeError(Mean): class MeanMetricWrapper(Mean): - """Wraps a stateless metric function with the Mean metric.""" + """Wraps a stateless metric function with the Mean metric. + + Args: + fn: The metric function to wrap, with signature `fn(y_true, y_pred, + **kwargs)`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + **kwargs: The keyword arguments that are passed on to `fn`. + """ def __init__(self, fn, name=None, dtype=None, **kwargs): - """Creates a `MeanMetricWrapper` instance. - - Args: - fn: The metric function to wrap, with signature - `fn(y_true, y_pred, **kwargs)`. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - **kwargs: The keyword arguments that are passed on to `fn`. - """ super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype) self._fn = fn self._fn_kwargs = kwargs @@ -640,6 +634,10 @@ class Accuracy(MeanMetricWrapper): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Accuracy() @@ -677,6 +675,12 @@ class BinaryAccuracy(MeanMetricWrapper): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + threshold: (Optional) Float representing the threshold for deciding + whether prediction values are 1 or 0. + Usage: >>> m = tf.keras.metrics.BinaryAccuracy() @@ -699,14 +703,6 @@ class BinaryAccuracy(MeanMetricWrapper): """ def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): - """Creates a `BinaryAccuracy` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - threshold: (Optional) Float representing the threshold for deciding - whether prediction values are 1 or 0. - """ super(BinaryAccuracy, self).__init__( binary_accuracy, name, dtype=dtype, threshold=threshold) @@ -729,6 +725,10 @@ class CategoricalAccuracy(MeanMetricWrapper): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.CategoricalAccuracy() @@ -756,12 +756,6 @@ class CategoricalAccuracy(MeanMetricWrapper): """ def __init__(self, name='categorical_accuracy', dtype=None): - """Creates a `CategoricalAccuracy` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(CategoricalAccuracy, self).__init__( categorical_accuracy, name, dtype=dtype) @@ -773,7 +767,7 @@ class SparseCategoricalAccuracy(MeanMetricWrapper): ```python acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)) ``` - + You can provide logits of classes as `y_pred`, since argmax of logits and probabilities are same. @@ -785,6 +779,10 @@ class SparseCategoricalAccuracy(MeanMetricWrapper): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.SparseCategoricalAccuracy() @@ -818,6 +816,12 @@ class SparseCategoricalAccuracy(MeanMetricWrapper): class TopKCategoricalAccuracy(MeanMetricWrapper): """Computes how often targets are in the top `K` predictions. + Args: + k: (Optional) Number of top elements to look at for computing accuracy. + Defaults to 5. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) @@ -842,14 +846,6 @@ class TopKCategoricalAccuracy(MeanMetricWrapper): """ def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None): - """Creates a `TopKCategoricalAccuracy` instance. - - Args: - k: (Optional) Number of top elements to look at for computing accuracy. - Defaults to 5. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(TopKCategoricalAccuracy, self).__init__( top_k_categorical_accuracy, name, dtype=dtype, k=k) @@ -858,6 +854,12 @@ class TopKCategoricalAccuracy(MeanMetricWrapper): class SparseTopKCategoricalAccuracy(MeanMetricWrapper): """Computes how often integer targets are in the top `K` predictions. + Args: + k: (Optional) Number of top elements to look at for computing accuracy. + Defaults to 5. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) @@ -882,38 +884,29 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper): """ def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None): - """Creates a `SparseTopKCategoricalAccuracy` instance. - - Args: - k: (Optional) Number of top elements to look at for computing accuracy. - Defaults to 5. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(SparseTopKCategoricalAccuracy, self).__init__( sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k) class _ConfusionMatrixConditionCount(Metric): - """Calculates the number of the given confusion matrix condition.""" + """Calculates the number of the given confusion matrix condition. + + Args: + confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. + thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple + of float threshold values in [0, 1]. A threshold is compared with + prediction values to determine the truth value of predictions (i.e., above + the threshold is `true`, below is `false`). One metric value is generated + for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + """ def __init__(self, confusion_matrix_cond, thresholds=None, name=None, dtype=None): - """Creates a `_ConfusionMatrixConditionCount` instance. - - Args: - confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. - thresholds: (Optional) Defaults to 0.5. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype) self._confusion_matrix_cond = confusion_matrix_cond self.init_thresholds = thresholds @@ -925,7 +918,7 @@ class _ConfusionMatrixConditionCount(Metric): initializer=init_ops.zeros_initializer) def update_state(self, y_true, y_pred, sample_weight=None): - """Accumulates the given confusion matrix condition statistics. + """Accumulates the metric statistics. Args: y_true: The ground truth values. @@ -973,6 +966,15 @@ class FalsePositives(_ConfusionMatrixConditionCount): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + thresholds: (Optional) Defaults to 0.5. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.FalsePositives() @@ -994,17 +996,6 @@ class FalsePositives(_ConfusionMatrixConditionCount): """ def __init__(self, thresholds=None, name=None, dtype=None): - """Creates a `FalsePositives` instance. - - Args: - thresholds: (Optional) Defaults to 0.5. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(FalsePositives, self).__init__( confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, thresholds=thresholds, @@ -1023,6 +1014,15 @@ class FalseNegatives(_ConfusionMatrixConditionCount): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + thresholds: (Optional) Defaults to 0.5. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.FalseNegatives() @@ -1044,17 +1044,6 @@ class FalseNegatives(_ConfusionMatrixConditionCount): """ def __init__(self, thresholds=None, name=None, dtype=None): - """Creates a `FalseNegatives` instance. - - Args: - thresholds: (Optional) Defaults to 0.5. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(FalseNegatives, self).__init__( confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, thresholds=thresholds, @@ -1073,6 +1062,15 @@ class TrueNegatives(_ConfusionMatrixConditionCount): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + thresholds: (Optional) Defaults to 0.5. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.TrueNegatives() @@ -1094,17 +1092,6 @@ class TrueNegatives(_ConfusionMatrixConditionCount): """ def __init__(self, thresholds=None, name=None, dtype=None): - """Creates a `TrueNegatives` instance. - - Args: - thresholds: (Optional) Defaults to 0.5. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(TrueNegatives, self).__init__( confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, thresholds=thresholds, @@ -1123,6 +1110,15 @@ class TruePositives(_ConfusionMatrixConditionCount): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + thresholds: (Optional) Defaults to 0.5. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.TruePositives() @@ -1144,17 +1140,6 @@ class TruePositives(_ConfusionMatrixConditionCount): """ def __init__(self, thresholds=None, name=None, dtype=None): - """Creates a `TruePositives` instance. - - Args: - thresholds: (Optional) Defaults to 0.5. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(TruePositives, self).__init__( confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, thresholds=thresholds, @@ -1183,6 +1168,21 @@ class Precision(Metric): top-k highest predictions, and computing the fraction of them for which `class_id` is indeed a correct label. + Args: + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate precision with `thresholds=0.5`. + top_k: (Optional) Unset by default. An int value specifying the top-k + predictions to consider when calculating precision. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Precision() @@ -1221,23 +1221,6 @@ class Precision(Metric): class_id=None, name=None, dtype=None): - """Creates a `Precision` instance. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate precision with `thresholds=0.5`. - top_k: (Optional) Unset by default. An int value specifying the top-k - predictions to consider when calculating precision. - class_id: (Optional) Integer class ID for which we want binary metrics. - This must be in the half-open interval `[0, num_classes)`, where - `num_classes` is the last dimension of predictions. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(Precision, self).__init__(name=name, dtype=dtype) self.init_thresholds = thresholds self.top_k = top_k @@ -1321,6 +1304,21 @@ class Recall(Metric): fraction of them for which `class_id` is above the threshold and/or in the top-k predictions. + Args: + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate recall with `thresholds=0.5`. + top_k: (Optional) Unset by default. An int value specifying the top-k + predictions to consider when calculating recall. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Recall() @@ -1347,23 +1345,6 @@ class Recall(Metric): class_id=None, name=None, dtype=None): - """Creates a `Recall` instance. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate recall with `thresholds=0.5`. - top_k: (Optional) Unset by default. An int value specifying the top-k - predictions to consider when calculating recall. - class_id: (Optional) Integer class ID for which we want binary metrics. - This must be in the half-open interval `[0, num_classes)`, where - `num_classes` is the last dimension of predictions. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(Recall, self).__init__(name=name, dtype=dtype) self.init_thresholds = thresholds self.top_k = top_k @@ -1541,6 +1522,13 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): For additional information about specificity and sensitivity, see the following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity + Args: + specificity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given specificity. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) @@ -1566,15 +1554,6 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): """ def __init__(self, specificity, num_thresholds=200, name=None, dtype=None): - """Creates a `SensitivityAtSpecificity` instance. - - Args: - specificity: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 200. The number of thresholds to - use for matching the given specificity. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ if specificity < 0 or specificity > 1: raise ValueError('`specificity` must be in the range [0, 1].') self.specificity = specificity @@ -1619,6 +1598,13 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): For additional information about specificity and sensitivity, see the following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity + Args: + sensitivity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given sensitivity. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) @@ -1644,15 +1630,6 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): """ def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None): - """Creates a `SpecificityAtSensitivity` instance. - - Args: - sensitivity: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 200. The number of thresholds to - use for matching the given sensitivity. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ if sensitivity < 0 or sensitivity > 1: raise ValueError('`sensitivity` must be in the range [0, 1].') self.sensitivity = sensitivity @@ -1689,6 +1666,13 @@ class PrecisionAtRecall(SensitivitySpecificityBase): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + recall: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given recall. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) @@ -1714,15 +1698,6 @@ class PrecisionAtRecall(SensitivitySpecificityBase): """ def __init__(self, recall, num_thresholds=200, name=None, dtype=None): - """Creates a `PrecisionAtRecall` instance. - - Args: - recall: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 200. The number of thresholds to - use for matching the given recall. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ if recall < 0 or recall > 1: raise ValueError('`recall` must be in the range [0, 1].') self.recall = recall @@ -1762,6 +1737,13 @@ class RecallAtPrecision(SensitivitySpecificityBase): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + precision: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given precision. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.RecallAtPrecision(0.8) @@ -1787,15 +1769,6 @@ class RecallAtPrecision(SensitivitySpecificityBase): """ def __init__(self, precision, num_thresholds=200, name=None, dtype=None): - """Creates a `RecallAtPrecision` instance. - - Args: - precision: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 200. The number of thresholds to - use for matching the given precision. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ if precision < 0 or precision > 1: raise ValueError('`precision` must be in the range [0, 1].') self.precision = precision @@ -1850,6 +1823,44 @@ class AUC(Metric): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use when discretizing the roc curve. Values must be > 1. + curve: (Optional) Specifies the name of the curve to be computed, 'ROC' + [default] or 'PR' for the Precision-Recall-curve. + summation_method: (Optional) Specifies the Riemann summation method used + (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default], + applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates + (true/false) positives but not the ratio that is precision (see Davis + & Goadrich 2006 for details); 'minoring' that applies left summation + for increasing intervals and right summation for decreasing intervals; + 'majoring' that does the opposite. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + thresholds: (Optional) A list of floating point values to use as the + thresholds for discretizing the curve. If set, the `num_thresholds` + parameter is ignored. Values should be in [0, 1]. Endpoint thresholds + equal to {-epsilon, 1+epsilon} for a small positive epsilon value will + be automatically included with these to correctly handle predictions + equal to exactly 0 or 1. + multi_label: boolean indicating whether multilabel data should be + treated as such, wherein AUC is computed separately for each label and + then averaged across labels, or (when False) if the data should be + flattened into a single label before AUC computation. In the latter + case, when multilabel data is passed to AUC, each label-prediction pair + is treated as an individual data point. Should be set to False for + multi-class data. + label_weights: (optional) list, array, or tensor of non-negative weights + used to compute AUCs for multilabel data. When `multi_label` is True, + the weights are applied to the individual label AUCs when they are + averaged to produce the multi-label AUC. When it's False, they are used + to weight the individual label predictions in computing the confusion + matrix on the flattened data. Note that this is unlike class_weights in + that class_weights weights the example depending on the value of its + label, whereas label_weights depends only on the index of that label + before flattening; therefore `label_weights` should not be used for + multi-class data. + Usage: >>> m = tf.keras.metrics.AUC(num_thresholds=3) @@ -1884,46 +1895,6 @@ class AUC(Metric): thresholds=None, multi_label=False, label_weights=None): - """Creates an `AUC` instance. - - Args: - num_thresholds: (Optional) Defaults to 200. The number of thresholds to - use when discretizing the roc curve. Values must be > 1. - curve: (Optional) Specifies the name of the curve to be computed, 'ROC' - [default] or 'PR' for the Precision-Recall-curve. - summation_method: (Optional) Specifies the Riemann summation method used - (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default], - applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates - (true/false) positives but not the ratio that is precision (see Davis - & Goadrich 2006 for details); 'minoring' that applies left summation - for increasing intervals and right summation for decreasing intervals; - 'majoring' that does the opposite. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - thresholds: (Optional) A list of floating point values to use as the - thresholds for discretizing the curve. If set, the `num_thresholds` - parameter is ignored. Values should be in [0, 1]. Endpoint thresholds - equal to {-epsilon, 1+epsilon} for a small positive epsilon value will - be automatically included with these to correctly handle predictions - equal to exactly 0 or 1. - multi_label: boolean indicating whether multilabel data should be - treated as such, wherein AUC is computed separately for each label and - then averaged across labels, or (when False) if the data should be - flattened into a single label before AUC computation. In the latter - case, when multilabel data is passed to AUC, each label-prediction pair - is treated as an individual data point. Should be set to False for - multi-class data. - label_weights: (optional) list, array, or tensor of non-negative weights - used to compute AUCs for multilabel data. When `multi_label` is True, - the weights are applied to the individual label AUCs when they are - averaged to produce the multi-label AUC. When it's False, they are used - to weight the individual label predictions in computing the confusion - matrix on the flattened data. Note that this is unlike class_weights in - that class_weights weights the example depending on the value of its - label, whereas label_weights depends only on the index of that label - before flattening; therefore `label_weights` should not be used for - multi-class data. - """ # Validate configurations. if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( metrics_utils.AUCCurve): @@ -2262,6 +2233,12 @@ class CosineSimilarity(MeanMetricWrapper): This metric keeps the average cosine similarity between `predictions` and `labels` over a stream of data. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + axis: (Optional) Defaults to -1. The dimension along which the cosine + similarity is computed. + Usage: >>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] @@ -2292,14 +2269,6 @@ class CosineSimilarity(MeanMetricWrapper): """ def __init__(self, name='cosine_similarity', dtype=None, axis=-1): - """Creates a `CosineSimilarity` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - axis: (Optional) Defaults to -1. The dimension along which the cosine - similarity is computed. - """ super(CosineSimilarity, self).__init__( cosine_similarity, name, dtype=dtype, axis=axis) @@ -2308,6 +2277,10 @@ class CosineSimilarity(MeanMetricWrapper): class MeanAbsoluteError(MeanMetricWrapper): """Computes the mean absolute error between the labels and predictions. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanAbsoluteError() @@ -2339,6 +2312,10 @@ class MeanAbsoluteError(MeanMetricWrapper): class MeanAbsolutePercentageError(MeanMetricWrapper): """Computes the mean absolute percentage error between `y_true` and `y_pred`. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanAbsolutePercentageError() @@ -2372,6 +2349,10 @@ class MeanAbsolutePercentageError(MeanMetricWrapper): class MeanSquaredError(MeanMetricWrapper): """Computes the mean squared error between `y_true` and `y_pred`. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanSquaredError() @@ -2403,6 +2384,10 @@ class MeanSquaredError(MeanMetricWrapper): class MeanSquaredLogarithmicError(MeanMetricWrapper): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanSquaredLogarithmicError() @@ -2439,6 +2424,10 @@ class Hinge(MeanMetricWrapper): `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided we will convert them to -1 or 1. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Hinge() @@ -2471,6 +2460,10 @@ class SquaredHinge(MeanMetricWrapper): `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided we will convert them to -1 or 1. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.SquaredHinge() @@ -2503,6 +2496,10 @@ class SquaredHinge(MeanMetricWrapper): class CategoricalHinge(MeanMetricWrapper): """Computes the categorical hinge metric between `y_true` and `y_pred`. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.CategoricalHinge() @@ -2593,6 +2590,10 @@ class LogCoshError(MeanMetricWrapper): `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true) + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.LogCoshError() @@ -2624,6 +2625,10 @@ class Poisson(MeanMetricWrapper): `metric = y_pred - y_true * log(y_pred)` + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.Poisson() @@ -2655,6 +2660,10 @@ class KLDivergence(MeanMetricWrapper): `metric = y_true * log(y_true / y_pred)` + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.KLDivergence() @@ -2695,6 +2704,13 @@ class MeanIoU(Metric): If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. + Args: + num_classes: The possible number of labels the prediction task can have. + This value must be provided, since a confusion matrix of dimension = + [num_classes, num_classes] will be allocated. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> # cm = [[1, 1], @@ -2725,15 +2741,6 @@ class MeanIoU(Metric): """ def __init__(self, num_classes, name=None, dtype=None): - """Creates a `MeanIoU` instance. - - Args: - num_classes: The possible number of labels the prediction task can have. - This value must be provided, since a confusion matrix of dimension = - [num_classes, num_classes] will be allocated. - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(MeanIoU, self).__init__(name=name, dtype=dtype) self.num_classes = num_classes @@ -2825,6 +2832,10 @@ class MeanTensor(Metric): `total` tracks the sum of the weighted values, and `count` stores the sum of the weighted counts. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + Usage: >>> m = tf.keras.metrics.MeanTensor() @@ -2839,12 +2850,6 @@ class MeanTensor(Metric): """ def __init__(self, name='mean_tensor', dtype=None): - """Creates a `MeanTensor` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - """ super(MeanTensor, self).__init__(name=name, dtype=dtype) self._shape = None self._total = None @@ -2936,6 +2941,16 @@ class BinaryCrossentropy(MeanMetricWrapper): This is the crossentropy metric class to be used when there are only two label classes (0 and 1). + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional )Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are + smoothed, meaning the confidence on label values are relaxed. + e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for + label `0` and `0.9` for label `1`". + Usage: >>> m = tf.keras.metrics.BinaryCrossentropy() @@ -2965,19 +2980,6 @@ class BinaryCrossentropy(MeanMetricWrapper): dtype=None, from_logits=False, label_smoothing=0): - """Creates a `BinaryCrossentropy` instance. - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - from_logits: (Optional )Whether output is expected to be a logits tensor. - By default, we consider that output encodes a probability distribution. - label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are - smoothed, meaning the confidence on label values are relaxed. - e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for - label `0` and `0.9` for label `1`" - """ - super(BinaryCrossentropy, self).__init__( binary_crossentropy, name, @@ -2995,6 +2997,16 @@ class CategoricalCrossentropy(MeanMetricWrapper): representation. eg., When labels values are [2, 0, 1], `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional) Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are + smoothed, meaning the confidence on label values are relaxed. e.g. + `label_smoothing=0.2` means that we will use a value of `0.1` for label + `0` and `0.9` for label `1`" + Usage: >>> # EPSILON = 1e-7, y = y_true, y` = y_pred @@ -3026,16 +3038,6 @@ class CategoricalCrossentropy(MeanMetricWrapper): loss='mse', metrics=[tf.keras.metrics.CategoricalCrossentropy()]) ``` - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor. - By default, we assume that `y_pred` encodes a probability distribution. - label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, - meaning the confidence on label values are relaxed. e.g. - `label_smoothing=0.2` means that we will use a value of `0.1` for label - `0` and `0.9` for label `1`" """ def __init__(self, @@ -3043,7 +3045,6 @@ class CategoricalCrossentropy(MeanMetricWrapper): dtype=None, from_logits=False, label_smoothing=0): - super(CategoricalCrossentropy, self).__init__( categorical_crossentropy, name, @@ -3067,6 +3068,14 @@ class SparseCategoricalCrossentropy(MeanMetricWrapper): The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is `[batch_size, num_classes]`. + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional) Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + axis: (Optional) Defaults to -1. The dimension along which the metric is + computed. + Usage: >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] @@ -3101,14 +3110,6 @@ class SparseCategoricalCrossentropy(MeanMetricWrapper): loss='mse', metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()]) ``` - - Args: - name: (Optional) string name of the metric instance. - dtype: (Optional) data type of the metric result. - from_logits: (Optional ) Whether `y_pred` is expected to be a logits tensor. - By default, we assume that `y_pred` encodes a probability distribution. - axis: (Optional) Defaults to -1. The dimension along which the metric is - computed. """ def __init__(self, @@ -3116,7 +3117,6 @@ class SparseCategoricalCrossentropy(MeanMetricWrapper): dtype=None, from_logits=False, axis=-1): - super(SparseCategoricalCrossentropy, self).__init__( sparse_categorical_crossentropy, name, @@ -3196,6 +3196,14 @@ def accuracy(y_true, y_pred): def binary_accuracy(y_true, y_pred, threshold=0.5): """Calculates how often predictions matches binary labels. + Usage: + >>> y_true = [[1], [1], [0], [0]] + >>> y_pred = [[1], [1], [0], [0]] + >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) + >>> assert m.shape == (4,) + >>> m.numpy() + array([1., 1., 1., 1.], dtype=float32) + Args: y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. @@ -3205,6 +3213,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): Returns: Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` """ + y_pred = ops.convert_to_tensor_v2(y_pred) threshold = math_ops.cast(threshold, y_pred.dtype) y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) return K.mean(math_ops.equal(y_true, y_pred), axis=-1) @@ -3214,6 +3223,14 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): def categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches one-hot labels. + Usage: + >>> y_true = [[0, 0, 1], [0, 1, 0]] + >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] + >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred) + >>> assert m.shape == (2,) + >>> m.numpy() + array([0., 1.], dtype=float32) + You can provide logits of classes as `y_pred`, since argmax of logits and probabilities are same. @@ -3234,6 +3251,14 @@ def categorical_accuracy(y_true, y_pred): def sparse_categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches integer labels. + Usage: + >>> y_true = [2, 1] + >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] + >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred) + >>> assert m.shape == (2,) + >>> m.numpy() + array([0., 1.], dtype=float32) + You can provide logits of classes as `y_pred`, since argmax of logits and probabilities are same. @@ -3244,8 +3269,10 @@ def sparse_categorical_accuracy(y_true, y_pred): Returns: Sparse categorical accuracy values. """ - y_pred_rank = ops.convert_to_tensor_v2(y_pred).shape.ndims - y_true_rank = ops.convert_to_tensor_v2(y_true).shape.ndims + y_pred = ops.convert_to_tensor_v2(y_pred) + y_true = ops.convert_to_tensor_v2(y_true) + y_pred_rank = y_pred.shape.ndims + y_true_rank = y_true.shape.ndims # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) if (y_true_rank is not None) and (y_pred_rank is not None) and (len( K.int_shape(y_true)) == len(K.int_shape(y_pred))): @@ -3264,6 +3291,14 @@ def sparse_categorical_accuracy(y_true, y_pred): def top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often targets are in the top `K` predictions. + Usage: + >>> y_true = [[0, 0, 1], [0, 1, 0]] + >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] + >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3) + >>> assert m.shape == (2,) + >>> m.numpy() + array([1., 1.], dtype=float32) + Args: y_true: The ground truth values. y_pred: The prediction values. @@ -3281,6 +3316,15 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5): def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often integer targets are in the top `K` predictions. + Usage: + >>> y_true = [2, 1] + >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] + >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy( + ... y_true, y_pred, k=3) + >>> assert m.shape == (2,) + >>> m.numpy() + array([1., 1.], dtype=float32) + Args: y_true: tensor of true targets. y_pred: tensor of predicted targets. @@ -3345,11 +3389,29 @@ def clone_metrics(metrics): @keras_export('keras.metrics.serialize') def serialize(metric): + """Serializes metric function or `Metric` instance. + + Arguments: + metric: A Keras `Metric` instance or a metric function. + + Returns: + Metric configuration dictionary. + """ return serialize_keras_object(metric) @keras_export('keras.metrics.deserialize') def deserialize(config, custom_objects=None): + """Deserializes a serialized metric class/function instance. + + Arguments: + config: Metric configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during deserialization. + + Returns: + A Keras `Metric` instance or a metric function. + """ return deserialize_keras_object( config, module_objects=globals(), @@ -3359,7 +3421,38 @@ def deserialize(config, custom_objects=None): @keras_export('keras.metrics.get') def get(identifier): - """Return a metric given its identifer.""" + """Retrieves a Keras metric as a `function`/`Metric` class instance. + + The `identifier` may be the string name of a metric function or class. + + >>> metric = tf.keras.metrics.get("categorical_crossentropy") + >>> type(metric) + + >>> metric = tf.keras.metrics.get("CategoricalCrossentropy") + >>> type(metric) + + + You can also specify `config` of the metric to this function by passing dict + containing `class_name` and `config` as an identifier. Also note that the + `class_name` must map to a `Metric` class + + >>> identifier = {"class_name": "CategoricalCrossentropy", + ... "config": {"from_logits": True}} + >>> metric = tf.keras.metrics.get(identifier) + >>> type(metric) + + + Arguments: + identifier: A metric identifier. One of None or string name of a metric + function/class or metric configuration dictionary or a metric function or + a metric class instance + + Returns: + A Keras metric as a `function`/ `Metric` class instance. + + Raises: + ValueError: If `identifier` cannot be interpreted. + """ if isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, six.string_types): diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py index 6b45ea1e95b..0adfad7d865 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.training.experimental import mixed_precision from tensorflow.python.util.tf_export import keras_export @@ -107,6 +108,8 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2): 0.25 """ + _HAS_AGGREGATE_GRAD = True + def __init__(self, optimizer, loss_scale): """Initializes this loss scale optimizer. @@ -271,9 +274,12 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2): def _apply_gradients(self, grads, wrapped_vars, name, experimental_aggregate_gradients): + # TODO(reedwm): This will raise a fairly cryptic error message if + # self._optimizer.apply_gradients does not take + # experimental_aggregate_gradients. return self._optimizer.apply_gradients( list(zip(grads, wrapped_vars.value)), name, - experimental_aggregate_gradients) + experimental_aggregate_gradients=experimental_aggregate_gradients) def get_config(self): serialized_optimizer = optimizers.serialize(self._optimizer) @@ -327,6 +333,9 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2): def set_weights(self, weights): return self._optimizer.set_weights(weights) + def _aggregate_gradients(self, grads_and_vars): + return self._optimizer._aggregate_gradients(grads_and_vars) # pylint: disable=protected-access + # For the most part, we only expose methods in the base OptimizerV2, not # individual subclasses like Adam. However, although "learning_rate" and "lr" # properties are not part of the base OptimizerV2 class, they are part of most @@ -382,6 +391,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2): # optimizer being used. +# pylint: disable=protected-access +mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2, + LossScaleOptimizer) + + def _multiply_gradient(gradient, scale): """Multiply a (possibly sparse) gradient by the given scale factor.""" scale = math_ops.cast(scale, gradient.dtype) diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index ff809d061cb..81d461c304d 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -166,6 +166,8 @@ class PolicyTest(test.TestCase, parameterized.TestCase): 'not passing any loss_scale instead.') for policy_name in 'float16', 'mixed_float16': + # Trigger any other warnings that occur only once + mp_policy.Policy(policy_name, loss_scale=2.) with test.mock.patch.object(tf_logging, 'warn') as mock_warn: mp_policy.Policy(policy_name, loss_scale=2.) mock_warn.assert_not_called() diff --git a/tensorflow/python/keras/ops.py b/tensorflow/python/keras/ops.py deleted file mode 100644 index 23bdc8dcd56..00000000000 --- a/tensorflow/python/keras/ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Module for exporting TensorFlow ops under tf.keras.*.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import init_ops_v2 -from tensorflow.python.util.tf_export import keras_export - - -# pylint: disable=bad-continuation -keras_export(v1=["keras.initializers.Initializer"])( - init_ops.Initializer) -keras_export(v1=["keras.initializers.Zeros", "keras.initializers.zeros"])( - init_ops.Zeros) -keras_export(v1=["keras.initializers.Ones", "keras.initializers.ones"])( - init_ops.Ones) -keras_export(v1=["keras.initializers.Constant", "keras.initializers.constant"])( - init_ops.Constant) -keras_export(v1=["keras.initializers.VarianceScaling"])( - init_ops.VarianceScaling) -keras_export(v1=["keras.initializers.Orthogonal", - "keras.initializers.orthogonal"])( - init_ops.Orthogonal) -keras_export(v1=["keras.initializers.Identity", - "keras.initializers.identity"])( - init_ops.Identity) -keras_export(v1=["keras.initializers.glorot_uniform"])( - init_ops.GlorotUniform) -keras_export(v1=["keras.initializers.glorot_normal"])( - init_ops.GlorotNormal) -keras_export(v1=["keras.initializers.lecun_normal"])( - init_ops.lecun_normal) -keras_export(v1=["keras.initializers.lecun_uniform"])( - init_ops.lecun_uniform) -keras_export(v1=["keras.initializers.he_normal"])( - init_ops.he_normal) -keras_export(v1=["keras.initializers.he_uniform"])( - init_ops.he_uniform) - -keras_export("keras.initializers.Initializer", v1=[])( - init_ops_v2.Initializer) -keras_export( - "keras.initializers.Zeros", "keras.initializers.zeros", v1=[])( - init_ops_v2.Zeros) -keras_export( - "keras.initializers.Ones", "keras.initializers.ones", v1=[])( - init_ops_v2.Ones) -keras_export( - "keras.initializers.Constant", "keras.initializers.constant", v1=[])( - init_ops_v2.Constant) -keras_export("keras.initializers.VarianceScaling", v1=[])( - init_ops_v2.VarianceScaling) -keras_export( - "keras.initializers.Orthogonal", "keras.initializers.orthogonal", v1=[])( - init_ops_v2.Orthogonal) -keras_export( - "keras.initializers.Identity", "keras.initializers.identity", v1=[])( - init_ops_v2.Identity) -keras_export( - "keras.initializers.GlorotUniform", - "keras.initializers.glorot_uniform", - v1=[])( - init_ops_v2.GlorotUniform) -keras_export( - "keras.initializers.GlorotNormal", - "keras.initializers.glorot_normal", - v1=[])( - init_ops_v2.GlorotNormal) -keras_export("keras.initializers.lecun_normal", v1=[])( - init_ops_v2.lecun_normal) -keras_export("keras.initializers.lecun_uniform", v1=[])( - init_ops_v2.lecun_uniform) -keras_export("keras.initializers.he_normal", v1=[])( - init_ops_v2.he_normal) -keras_export("keras.initializers.he_uniform", v1=[])( - init_ops_v2.he_uniform) -keras_export( - "keras.initializers.RandomNormal", - "keras.initializers.random_normal", - v1=[])( - init_ops_v2.RandomNormal) -keras_export( - "keras.initializers.RandomUniform", - "keras.initializers.random_uniform", - v1=[])( - init_ops_v2.RandomUniform) -keras_export( - "keras.initializers.TruncatedNormal", - "keras.initializers.truncated_normal", - v1=[])( - init_ops_v2.TruncatedNormal) -# pylint: enable=bad-continuation - -keras_export(v1=["keras.backend.name_scope"])(ops.name_scope_v1) diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index c5eab79f6c2..8636ffb237e 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -66,6 +66,20 @@ py_library( ], ) +py_library( + name = "legacy_learning_rate_decay", + srcs = ["legacy_learning_rate_decay.py"], + srcs_version = "PY2AND3", + deps = [ + ":learning_rate_schedule", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python/eager:context", + ], +) + cuda_py_test( name = "adagrad_test", size = "medium", @@ -245,6 +259,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "legacy_learning_rate_decay_test", + size = "medium", + srcs = ["legacy_learning_rate_decay_test.py"], + deps = [ + ":legacy_learning_rate_decay", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training_lib", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + cuda_py_test( name = "rmsprop_test", size = "medium", diff --git a/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py new file mode 100644 index 00000000000..f86e68d188f --- /dev/null +++ b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py @@ -0,0 +1,771 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various learning rate decay functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule +from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=["train.exponential_decay"]) +def exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies exponential decay to the learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an exponential decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + + ```python + decayed_learning_rate = learning_rate * + decay_rate ^ (global_step / decay_steps) + ``` + + If the argument `staircase` is `True`, then `global_step / decay_steps` is an + integer division and the decayed learning rate follows a staircase function. + + Example: decay every 100000 steps with a base of 0.96: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate, + global_step, + 100000, 0.96, staircase=True) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.compat.v1.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. Must not be negative. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must + be positive. See the decay computation above. + decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number. + The decay rate. + staircase: Boolean. If `True` decay the learning rate at discrete intervals + name: String. Optional name of the operation. Defaults to + 'ExponentialDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + + Raises: + ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.ExponentialDecay( + learning_rate, decay_steps, decay_rate, staircase=staircase, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.piecewise_constant_decay", "train.piecewise_constant"]) +def piecewise_constant(x, boundaries, values, name=None): + """Piecewise constant from boundaries and interval values. + + Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. + + ```python + global_step = tf.Variable(0, trainable=False) + boundaries = [100000, 110000] + values = [1.0, 0.5, 0.1] + learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries, + values) + + # Later, whenever we perform an optimization step, we increment global_step. + ``` + + Args: + x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`, + `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`. + boundaries: A list of `Tensor`s or `int`s or `float`s with strictly + increasing entries, and with all elements having the same type as `x`. + values: A list of `Tensor`s or `float`s or `int`s that specifies the values + for the intervals defined by `boundaries`. It should have one more element + than `boundaries`, and all elements should have the same type. + name: A string. Optional name of the operation. Defaults to + 'PiecewiseConstant'. + + Returns: + A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`, + `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ..., + and values[-1] when `x > boundaries[-1]`. + + Raises: + ValueError: if types of `x` and `boundaries` do not match, or types of all + `values` do not match or + the number of elements in the lists does not match. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + boundaries = ops.convert_n_to_tensor(boundaries) + values = ops.convert_n_to_tensor(values) + x_recomp = ops.convert_to_tensor(x) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + boundaries, values, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr(x) + else: + decayed_lr = functools.partial(decayed_lr, x) + return decayed_lr + + +@tf_export(v1=["train.polynomial_decay"]) +def polynomial_decay(learning_rate, + global_step, + decay_steps, + end_learning_rate=0.0001, + power=1.0, + cycle=False, + name=None): + """Applies a polynomial decay to the learning rate. + + It is commonly observed that a monotonically decreasing learning rate, whose + degree of change is carefully chosen, results in a better performing model. + This function applies a polynomial decay function to a provided initial + `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`. + + It requires a `global_step` value to compute the decayed learning rate. You + can just pass a TensorFlow variable that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + + ```python + global_step = min(global_step, decay_steps) + decayed_learning_rate = (learning_rate - end_learning_rate) * + (1 - global_step / decay_steps) ^ (power) + + end_learning_rate + + ``` + + If `cycle` is True then a multiple of `decay_steps` is used, the first one + that is bigger than `global_steps`. + + ```python + decay_steps = decay_steps * ceil(global_step / decay_steps) + decayed_learning_rate = (learning_rate - end_learning_rate) * + (1 - global_step / decay_steps) ^ (power) + + end_learning_rate + + ``` + + Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5): + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + end_learning_rate = 0.01 + decay_steps = 10000 + learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate, + global_step, + decay_steps, end_learning_rate, + power=0.5) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.compat.v1.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. Must not be negative. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must + be positive. See the decay computation above. + end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python + number. The minimal end learning rate. + power: A scalar `float32` or `float64` `Tensor` or a Python number. The + power of the polynomial. Defaults to linear, 1.0. + cycle: A boolean, whether or not it should cycle beyond decay_steps. + name: String. Optional name of the operation. Defaults to + 'PolynomialDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + + Raises: + ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.PolynomialDecay( + learning_rate, + decay_steps, + end_learning_rate=end_learning_rate, + power=power, + cycle=cycle, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.natural_exp_decay"]) +def natural_exp_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies natural exponential decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an exponential decay function + to a provided initial learning rate. It requires an `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / + decay_step)) + ``` + + Example: decay exponentially with a base of 0.96: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + learning_rate = 0.1 + decay_steps = 5 + k = 0.5 + learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate, + global_step, + decay_steps, k) + + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.compat.v1.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. + The initial learning rate. + global_step: A Python number. Global step to use for the decay computation. + Must not be negative. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as opposed to + continuous, fashion. + name: String. Optional name of the operation. Defaults to + 'ExponentialTimeDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + + Raises: + ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate)) + decayed_lr = learning_rate_schedule.ExponentialDecay( + learning_rate, + decay_steps, + natural_exp_rate, + staircase=staircase, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.inverse_time_decay"]) +def inverse_time_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies inverse time decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an inverse decay function + to a provided initial learning rate. It requires an `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + + ```python + decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / + decay_step)) + ``` + + Example: decay 1/t with a rate of 0.5: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + learning_rate = 0.1 + decay_steps = 1.0 + decay_rate = 0.5 + learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate, + global_step, + decay_steps, decay_rate) + + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.compat.v1.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. + The initial learning rate. + global_step: A Python number. Global step to use for the decay computation. + Must not be negative. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as opposed to + continuous, fashion. + name: String. Optional name of the operation. Defaults to + 'InverseTimeDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + + Raises: + ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps, decay_rate, staircase=staircase, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.cosine_decay"]) +def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): + """Applies cosine decay to the learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a cosine decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + ```python + global_step = min(global_step, decay_steps) + cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + decayed_learning_rate = learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed = cosine_decay(learning_rate, global_step, decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number + of steps to decay over. + alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum + learning rate value as a fraction of learning_rate. + name: String. Optional name of the operation. Defaults to 'CosineDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + Raises: + ValueError: if `global_step` is not supplied. + + References: + Stochastic Gradient Descent with Warm Restarts: + [Loshchilov et al., 2017] + (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) + ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.CosineDecay( + learning_rate, decay_steps, alpha=alpha, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.cosine_decay_restarts"]) +def cosine_decay_restarts(learning_rate, + global_step, + first_decay_steps, + t_mul=2.0, + m_mul=1.0, + alpha=0.0, + name=None): + """Applies cosine decay with restarts to the learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a cosine decay function with + restarts to a provided initial learning rate. It requires a `global_step` + value to compute the decayed learning rate. You can just pass a TensorFlow + variable that you increment at each training step. + + The function returns the decayed learning rate while taking into account + possible warm restarts. The learning rate multiplier first decays + from 1 to `alpha` for `first_decay_steps` steps. Then, a warm + restart is performed. Each new warm restart runs for `t_mul` times more steps + and with `m_mul` times smaller initial learning rate. + + Example usage: + ```python + first_decay_steps = 1000 + lr_decayed = cosine_decay_restarts(learning_rate, global_step, + first_decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. + first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to + derive the number of iterations in the i-th period + m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Used to derive the initial learning rate of the i-th period: + alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum + learning rate value as a fraction of the learning_rate. + name: String. Optional name of the operation. Defaults to 'SGDRDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + Raises: + ValueError: if `global_step` is not supplied. + + References: + Stochastic Gradient Descent with Warm Restarts: + [Loshchilov et al., 2017] + (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) + ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.CosineDecayRestarts( + learning_rate, + first_decay_steps, + t_mul=t_mul, + m_mul=m_mul, + alpha=alpha, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.linear_cosine_decay"]) +def linear_cosine_decay(learning_rate, + global_step, + decay_steps, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies linear cosine decay to the learning rate. + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a linear cosine decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns the decayed learning rate. It is computed as: + ```python + global_step = min(global_step, decay_steps) + linear_decay = (decay_steps - global_step) / decay_steps) + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) + decayed = (alpha + linear_decay) * cosine_decay + beta + decayed_learning_rate = learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed = linear_cosine_decay(learning_rate, global_step, decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number + of steps to decay over. + num_periods: Number of periods in the cosine part of the decay. See + computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'LinearCosineDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + Raises: + ValueError: if `global_step` is not supplied. + + References: + Neural Optimizer Search with Reinforcement Learning: + [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html) + ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf)) + Stochastic Gradient Descent with Warm Restarts: + [Loshchilov et al., 2017] + (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) + ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.LinearCosineDecay( + learning_rate, + decay_steps, + num_periods=num_periods, + alpha=alpha, + beta=beta, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr + + +@tf_export(v1=["train.noisy_linear_cosine_decay"]) +def noisy_linear_cosine_decay(learning_rate, + global_step, + decay_steps, + initial_variance=1.0, + variance_decay=0.55, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies noisy linear cosine decay to the learning rate. + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a noisy linear + cosine decay function to a provided initial learning rate. + It requires a `global_step` value to compute the decayed learning rate. + You can just pass a TensorFlow variable that you increment at each + training step. + + The function returns the decayed learning rate. It is computed as: + ```python + global_step = min(global_step, decay_steps) + linear_decay = (decay_steps - global_step) / decay_steps) + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) + decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta + decayed_learning_rate = learning_rate * decayed + ``` + where eps_t is 0-centered gaussian noise with variance + initial_variance / (1 + global_step) ** variance_decay + + Example usage: + ```python + decay_steps = 1000 + lr_decayed = noisy_linear_cosine_decay( + learning_rate, global_step, decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global + step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number + of steps to decay over. + initial_variance: initial variance for the noise. See computation above. + variance_decay: decay for the noise's variance. See computation above. + num_periods: Number of periods in the cosine part of the decay. See + computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'NoisyLinearCosineDecay'. + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + Raises: + ValueError: if `global_step` is not supplied. + + References: + Neural Optimizer Search with Reinforcement Learning: + [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html) + ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf)) + Stochastic Gradient Descent with Warm Restarts: + [Loshchilov et al., 2017] + (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) + ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility + """ + decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay( + learning_rate, + decay_steps, + initial_variance=initial_variance, + variance_decay=variance_decay, + num_periods=num_periods, + alpha=alpha, + beta=beta, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) + return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay_test.py similarity index 99% rename from tensorflow/python/training/learning_rate_decay_test.py rename to tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay_test.py index 1029d4cea8f..b5a3197ca67 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay_test.py @@ -22,11 +22,11 @@ import math from tensorflow.python.eager import context from tensorflow.python.framework import test_util +from tensorflow.python.keras.optimizer_v2 import legacy_learning_rate_decay as learning_rate_decay # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -from tensorflow.python.training import learning_rate_decay class LRDecayTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 20515beb0eb..37ec1e933ff 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -220,6 +220,36 @@ class OptimizerV2(trackable.Trackable): opt.minimize(loss, var_list=[var1, var2]) ``` + ### Callable learning rate. + Optimizer accepts a callable learning rate in two ways. The first way is + through built-in or customized + `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be + called on each iteration with `schedule(iteration)`, a `tf.Variable` + owned by the optimizer. + + Example: + + >>> var = tf.Variable(np.random.random(size=(1,))) + >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( + ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1) + >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate) + >>> loss = lambda: 3 * var + >>> opt.minimize(loss, var_list=[var]) + >> var = tf.Variable(np.random.random(size=(1,))) + >>> def lr_callable(): + ... return .1 + >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable) + >>> loss = lambda: 3 * var + >>> opt.minimize(loss, var_list=[var]) + >> y = [0, 1, 2, 3] - >>> tf.keras.utils.to_categorical(y, num_classes=4) - array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]], dtype=float32) + >>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4) + >>> a = tf.constant(a, shape=[4, 4]) + >>> print(a) + tf.Tensor( + [[1. 0. 0. 0.] + [0. 1. 0. 0.] + [0. 0. 1. 0.] + [0. 0. 0. 1.]], shape=(4, 4), dtype=float32) + + >>> b = tf.constant([.9, .04, .03, .03, + ... .3, .45, .15, .13, + ... .04, .01, .94, .05, + ... .12, .21, .5, .17], + ... shape=[4, 4]) + >>> loss = tf.keras.backend.categorical_crossentropy(a, b) + >>> print(np.around(loss, 5)) + [0.10536 0.82807 0.1011 1.77196] + + >>> loss = tf.keras.backend.categorical_crossentropy(a, a) + >>> print(np.around(loss, 5)) + [0. 0. 0. 0.] Arguments: y: class vector to be converted into a matrix (integers from 0 to num_classes). - num_classes: total number of classes. + num_classes: total number of classes. If `None`, this would be inferred + as the (largest number in `y`) + 1. dtype: The data type expected by the input. Default: `'float32'`. Returns: A binary matrix representation of the input. The classes axis is placed last. + + Usage example: + + >>> y = [0, 1, 2, 3, 3, 1, 0] + >>> tf.keras.utils.to_categorical(y, 4) + array([[1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 1., 0., 0.], + [1., 0., 0., 0.]], dtype=float32) + + Raises: + Value Error: If input contains string value + """ y = np.array(y, dtype='int') input_shape = y.shape diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index dab5243663b..9819cb831e2 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -261,6 +261,22 @@ def plot_model(model, dpi=96): """Converts a Keras model to dot format and save to a file. + Example: + + ```python + input = tf.keras.Input(shape=(100,), dtype='int32', name='input') + x = tf.keras.layers.Embedding( + output_dim=512, input_dim=10000, input_length=100)(input) + x = tf.keras.layers.LSTM(32)(x) + x = tf.keras.layers.Dense(64, activation='relu')(x) + x = tf.keras.layers.Dense(64, activation='relu')(x) + x = tf.keras.layers.Dense(64, activation='relu')(x) + output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x) + model = tf.keras.Model(inputs=[input], outputs=[output]) + dot_img_file = '/tmp/model_1.png' + tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) + ``` + Arguments: model: A Keras model instance to_file: File name of the plot image. diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 2533cf0a645..eec7165d148 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -3011,6 +3011,25 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(f(), 4. * 2.**3) # 4 * x_init ^ 3 + @test_util.run_deprecated_v1 + def testTfFunctionInV1WhileLoop(self): + + # This test specifically tests that creating a Const node inside a + # tf.function inside a v1 while_loop while inlining is turned on works. + config = opt_cfg() + assert config.graph_options.optimizer_options.do_function_inlining + with session.Session(config=config): + + @def_function.function + def loop_body(i): + # Here we create the const. + return i + 1. + + loop_cond = lambda i: True + x = control_flow_ops.while_loop( + loop_cond, loop_body, [0.], maximum_iterations=5) + self.assertAllEqual(x, 5.) + def _testNestedWhileCondWhileGrad(self, use_gpu): with self.cached_session(use_gpu=use_gpu): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index b4abcfa3a45..5d58a325d3c 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -836,8 +836,9 @@ class Conv2DTest(test.TestCase): x2 = self._CreateNumpyTensor(output_sizes) dilations = list(dilations) with test_util.device(use_gpu): - if data_format == "NCHW": - input_sizes = test_util.NHWCToNCHW(input_sizes) + if len(input_sizes) == 4: + if data_format == "NCHW": + input_sizes = test_util.NHWCToNCHW(input_sizes) t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) t1 = constant_op.constant(x1, shape=filter_sizes) t2 = constant_op.constant(x2, shape=output_sizes) @@ -1007,6 +1008,22 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) + @test_util.run_in_graph_and_eager_modes + @test_util.disable_xla("XLA requires input_sizes to be a 4D shape.") + def testConv2DInputSizesContainsOnlySpatialDimensionsBackpropInput(self): + expected_output = [5.0, 11.0, 17.0, 23.0] + for (data_format, use_gpu) in GetTestConfigs(): + self._RunAndVerifyBackpropInput( + input_sizes=[2, 2], + filter_sizes=[2, 2, 1, 2], + output_sizes=[1, 1, 1, 2], + strides=[1, 1], + padding="VALID", + expected=expected_output, + data_format=data_format, + use_gpu=use_gpu, + err=1e-5) + # Testing for backprops def _RunAndVerifyBackpropFilter(self, input_sizes, diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py index 2ef233f8d68..c636cee0dd5 100644 --- a/tensorflow/python/kernel_tests/in_topk_op_test.py +++ b/tensorflow/python/kernel_tests/in_topk_op_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class InTopKTest(test.TestCase): def _validateInTopK(self, predictions, target, k, expected): - np_ans = np.array(expected) + np_ans = np.array(expected, np.bool) with self.cached_session(use_gpu=True): precision = nn_ops.in_top_k(predictions, target, k) out = self.evaluate(precision) @@ -66,6 +66,11 @@ class InTopKTest(test.TestCase): target = [2, 4] # must return False for invalid target self._validateInTopK(predictions, target, 2, [True, False]) + def testEmpty(self): + predictions = np.empty([0, 5]) + target = np.empty([0], np.int32) + self._validateInTopK(predictions, target, 2, []) + def testTensorK(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] target = [0, 2] diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index ff8793c46ec..cba154cec4a 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -615,6 +615,150 @@ class LinSpaceTest(test.TestCase): np.array([0., .1], np.float64)) +class LinSpaceNdTest(test.TestCase): + + def _gpu_modes(self): + if test.is_gpu_available(): + return [False, True] + else: + return [False] + + def _LinSpace(self, start, stop, num, axis=0): + with ops.Graph().as_default() as graph: + with self.session(graph=graph, force_gpu=self.force_gpu): + tf_ans = math_ops.linspace_nd(start, stop, num, axis=axis) + return self.evaluate(tf_ans) + + def _LinSpaceNumConstant(self, start, stop, num, axis=0): + with ops.Graph().as_default() as graph: + num_constant = constant_op.constant(num) + with self.session(graph=graph, force_gpu=self.force_gpu): + tf_ans = math_ops.linspace_nd(start, stop, num_constant, axis=axis) + return self.evaluate(tf_ans) + + def _LinspaceNoneShape(self, start, stop, num, graph_shape=None, axis=0): + with ops.Graph().as_default() as graph: + num_tensor = array_ops.placeholder(dtypes.int32) + start_t = array_ops.placeholder(dtypes.float32, shape=graph_shape) + stop_t = array_ops.placeholder(dtypes.float32, shape=graph_shape) + ans_tensor = math_ops.linspace_nd(start_t, stop_t, num_tensor, axis=axis) + + with self.session(graph=graph, force_gpu=self.force_gpu) as sess: + feed_dict = {start_t: start, stop_t: stop, num_tensor: num} + return sess.run(ans_tensor, feed_dict=feed_dict) + + def testPositive(self): + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5) + self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(1., 5., 3), np.array([1., 3., 5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(1., 5., 4), np.array([1., 7. / 3., 11. / 3., 5.]), + 1e-5) + + def testNegative(self): + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., -5., 2), np.array([-1., -5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., -5., 3), np.array([-1., -3., -5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., -5., 4), np.array([-1., -7. / 3., -11. / 3., + -5.]), 1e-5) + + def testNegativeToPositive(self): + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., 5., 2), np.array([-1., 5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., 5., 3), np.array([-1., 2., 5.]), 1e-5) + self.assertArrayNear( + self._LinSpace(-1., 5., 4), np.array([-1., 1., 3., 5.]), 1e-5) + + def testPoint(self): + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5) + + def testEndpointsAreExact(self): + for self.force_gpu in self._gpu_modes(): + # Test some cases that produce last values not equal to "stop" when + # computed via start + (num - 1) * ((stop - start) / (num - 1)), since + # float arithmetic will introduce error through precision loss. + self.assertAllEqual( + self._LinSpace(0., 1., 42)[[0, -1]], np.array([0., 1.], np.float32)) + self.assertAllEqual( + self._LinSpace(-1., 0., 42)[[0, -1]], np.array([-1., 0.], np.float32)) + self.assertAllEqual( + self._LinSpace(.1, .2, 4)[[0, -1]], np.array([.1, .2], np.float32)) + # Check a case for float64 error too. + self.assertAllEqual( + self._LinSpace(np.array(0., np.float64), .1, 12)[[0, -1]], + np.array([0., .1], np.float64)) + + def testScalarsCompareToNumpy(self): + for self.force_gpu in self._gpu_modes(): + actual = self._LinSpace(0., 1., 32) + expected = np.linspace(0., 1., 32) + self.assertArrayNear(expected, actual, 1e-5) + + def _baseNDArrayCompareToNumpy(self, axis): + for self.force_gpu in self._gpu_modes(): + a, b, expected, num = self.create_nd_inputs_and_expected_output(axis) + actual = self._LinSpace(a, b, num, axis=axis) + self.assert_close(actual, expected) + + def assert_close(self, actual, expected): + wrong_indices = np.where(~np.allclose(actual, expected)) + mess = "Wrong float answer. Wrong indices: {}".format(wrong_indices) + self.assertTrue(np.allclose(actual, expected), mess) + + def create_nd_inputs_and_expected_output(self, axis): + a = np.arange(2, dtype=np.float32) + b = a * 5 + num = 5 + + res = np.array([[0., 0., 0., 0., 0.], [1., 2., 3., 4., 5.]]) + expected = res if axis != 0 else res.T + return a, b, expected, num + + def testNDArrayCompareToNumpyDefaultAxis(self): + self._baseNDArrayCompareToNumpy(0) + + def testNDArrayAxisStrictlyPositive(self): + self._baseNDArrayCompareToNumpy(1) + + def testNDArrayAxisStrictlyNegative(self): + self._baseNDArrayCompareToNumpy(-1) + + def testNumConstant(self): + for self.force_gpu in self._gpu_modes(): + actual = self._LinSpaceNumConstant(0., 1., 32) + expected = np.linspace(0., 1., 32) + self.assertArrayNear(expected, actual, 1e-5) + + def testUnknownShapeAtGraphCreationTime(self): + self.base_test_unknown_shape((2)) + + def testNoneValuesInShapeAtGraphCreationTime(self): + self.base_test_unknown_shape((None)) + + def testNoneShapeAtGraphCreationTime(self): + self.base_test_unknown_shape(None) + + def base_test_unknown_shape(self, graph_shape): + for self.force_gpu in self._gpu_modes(): + axis = 1 + a, b, expected, num = self.create_nd_inputs_and_expected_output(axis) + actual = self._LinspaceNoneShape(a, b, num, graph_shape, axis) + self.assert_close(actual, expected) + + class DeviceTest(test.TestCase): def testNoDevice(self): diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py index 5cd206ccbc1..4841c18a78c 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py @@ -106,10 +106,7 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase): # These tests are refactored from sparse_csr_matrix_grad_test to keep its size # "medium". -dtypes_to_test = [np.float32] -if not test.is_built_with_rocm: - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] +dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: for (t_a, t_b, adj_a, adj_b, t_out, conj_out) in itertools.product(*(([False, True],) * 6)): diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index 51757802968..ac82f190db0 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -517,9 +517,6 @@ class CSRSparseMatrixOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testSparseMatrixMatMulConjugateOutput(self): - if test.is_built_with_rocm(): - self.skipTest("complex type not supported on ROCm") - for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]: a_indices = np.array([[0, 0], [2, 3]]) a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64) @@ -542,17 +539,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMul(self): - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex types is not supported on the ROCm platform - dtypes_to_test += [np.complex64] - - if test.is_built_with_rocm(): - # TODO(rocm): fix this - # This test is currently failing on the ROCm platform - # Ren-enable it once the fix is available - self.skipTest("hipSPARSE all failure on the ROCm platform") - + dtypes_to_test = [np.float32, np.complex64] sparsify = lambda m: m * (m > 0) for dtype in dtypes_to_test: for (transpose_a, transpose_b) in ((False, False), (False, True), diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py index 66077f5b2d2..35c706cb36a 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py @@ -154,10 +154,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a))).astype(dtype) @@ -198,10 +195,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a))).astype(dtype) @@ -239,10 +233,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = (np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a)).astype(dtype) diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index e618e21ed9d..53ebdd3ab88 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -1665,6 +1665,26 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(f(), [b"A", b"B", b"C"]) + def testPopBackGrad(self): + # https://github.com/tensorflow/tensorflow/issues/37230 + + @def_function.function + def g(x): + x_prod = constant_op.constant([1.]) + for unused_i in math_ops.range(3): + x_prod = x_prod * x + return x_prod + + x = constant_op.constant(1.) + with backprop.GradientTape() as t: + t.watch(x) + with backprop.GradientTape() as tt: + tt.watch(x) + loss = g(x) + jac = tt.gradient(loss, x) + hess = t.gradient(jac, x) + self.assertAllEqual(hess, 6.) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 5faa07baf94..6486f42156f 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -92,10 +91,9 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) { for (int64 i = 0; i < n; ++i) { PyObject* arg = nullptr; if (call->eager) { - TensorHandle* handle; Tensor t = call->ins[i]; - TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( - std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle)); + TensorHandle* handle = TensorHandle::CreateLocalHandle( + std::move(t), ctx->CanonicalDevice(device), nullptr, ctx); arg = EagerTensorFromHandle(new TFE_TensorHandle{ std::make_unique(handle)}); if (arg == nullptr) { @@ -146,9 +144,8 @@ bool IsSingleNone(PyObject* obj) { tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, const Device* expected_device, const Tensor** output_tensor) { - auto handle = down_cast( - EagerTensor_Handle(eager_tensor)->handle.get()) - ->Handle(); + tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface( + EagerTensor_Handle(eager_tensor)->handle); if (VariantDeviceIsCustom(handle->device())) { return errors::Unimplemented( "Custom devices are currently not supported with PyFuncs."); @@ -191,18 +188,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { // Prepare the argument. PyObject* args = nullptr; - TFE_Context* ctx = nullptr; std::unique_ptr new_executor = nullptr; EagerExecutor* old_executor = nullptr; if (call->eager) { // See FuncRegistry._ctx. - ctx = reinterpret_cast(PyCapsule_GetPointer( + TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); CHECK_NE(ctx, nullptr); - TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args)); + EagerContext* context = ContextFromInterface(ctx->context); + TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args)); new_executor.reset(new EagerExecutor(call->eager_async)); - old_executor = &ctx->context->Executor(); - ctx->context->SetExecutorForThread(new_executor.get()); + old_executor = &context->Executor(); + context->SetExecutorForThread(new_executor.get()); } else { TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args)); } @@ -236,8 +233,11 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } if (new_executor != nullptr) { + TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( + PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); + EagerContext* context = ContextFromInterface(ctx->context); s.Update(new_executor->WaitForAllPendingNodes()); - ctx->context->SetExecutorForThread(old_executor); + context->SetExecutorForThread(old_executor); } TF_RETURN_IF_ERROR(s); diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index e81102847ea..1cd6c68b671 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -17,12 +17,12 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" @@ -278,29 +278,23 @@ struct Converter { static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, TFE_TensorHandle** h, const char** error) { // TODO(josh11b): Allocator & attributes - // TODO(gjn): Use optimized scalar constructors when possible. - Tensor result(ConverterTraits::kTypeEnum, - TensorShape(state->inferred_shape)); + std::unique_ptr t; if (state->inferred_shape.empty()) { /* Scalar case */ T value; auto scalar = ZeroDimArrayToScalar(obj, state); *error = ConverterTraits::ConvertScalar(scalar, &value); Py_DECREF(scalar); if (*error != nullptr) return errors::InvalidArgument(*error); - result.scalar()() = value; + t = ConverterTraits::CreateScalar(ctx, value); } else { - T* buf = result.flat().data(); - *error = Helper(obj, 0, state, &buf); - if (*error != nullptr) return errors::InvalidArgument(*error); + t = ConverterTraits::CreateTensor(ctx, state->inferred_shape); + if (t->NumElements() > 0) { + T* buf = static_cast(t->Data()); + *error = Helper(obj, 0, state, &buf); + if (*error != nullptr) return errors::InvalidArgument(*error); + } } - tensorflow::TensorHandle* handle = nullptr; - auto status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &handle); - if (!status.ok()) { - return status; - } - *h = new TFE_TensorHandle{std::make_unique(handle)}; + *h = new TFE_TensorHandle{ctx->context->CreateLocalHandle(std::move(t))}; return Status::OK(); } }; @@ -309,7 +303,15 @@ struct Converter { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int64 value) { + return ctx->context->CreateInt64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateInt64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, int64* out) { #if PY_MAJOR_VERSION < 3 @@ -342,7 +344,15 @@ typedef Converter Int64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_UINT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + uint64 value) { + return ctx->context->CreateUint64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateUint64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, uint64* out) { #if PY_MAJOR_VERSION < 3 @@ -372,7 +382,15 @@ typedef Converter UInt64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT32; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int32 value) { + return ctx->context->CreateInt32Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateInt32Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, int32* out) { int64 i; @@ -472,7 +490,16 @@ static const char* ConvertOneFloat(PyObject* v, T* out) { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_FLOAT; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + float value) { + return ctx->context->CreateFloatScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateFloatTensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, float* out) { return ConvertOneFloat(v, out); } @@ -480,7 +507,16 @@ struct ConverterTraits { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_DOUBLE; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + double value) { + return ctx->context->CreateDoubleScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateDoubleTensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, double* out) { return ConvertOneFloat(v, out); } @@ -491,7 +527,15 @@ typedef Converter FloatConverter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_HALF; + static std::unique_ptr CreateScalar( + TFE_Context* ctx, Eigen::half value) { + return ctx->context->CreateHalfScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateHalfTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, Eigen::half* out) { return ConvertOneFloat(v, out); @@ -504,7 +548,15 @@ typedef Converter NumpyHalfConverter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_STRING; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + tstring value) { + return ctx->context->CreateStringScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateStringTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, tstring* out) { if (PyBytes_Check(v)) { @@ -563,7 +615,16 @@ bool IsPyDimension(PyObject* obj) { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_COMPLEX128; + static std::unique_ptr CreateScalar( + TFE_Context* ctx, complex128 value) { + return ctx->context->CreateComplex128Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateComplex128Tensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, complex128* out) { if (PyComplex_Check(v)) { *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v)); @@ -583,8 +644,15 @@ typedef Converter Complex128Converter; template <> struct ConverterTraits { - typedef bool Type; - static const tensorflow::DataType kTypeEnum = DT_BOOL; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + bool value) { + return ctx->context->CreateBoolScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateBoolTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, bool* out) { if (v == Py_True) { @@ -606,24 +674,20 @@ typedef Converter BoolConverter; // The two may share underlying storage so changes to one may reflect in the // other. TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { - tensorflow::TensorHandle* handle; + std::unique_ptr handle; tensorflow::Tensor t; - auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); - if (cppstatus.ok()) { - cppstatus = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &handle); - } - if (!cppstatus.ok()) { + tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &t); + if (!status.ok()) { PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat( "Failed to convert a NumPy array to a Tensor (", - cppstatus.error_message(), ").") + status.error_message(), ").") .c_str()); return nullptr; } - return new TFE_TensorHandle{ - std::make_unique(handle)}; + + return new TFE_TensorHandle{ctx->context->CreateLocalHandle( + std::make_unique(std::move(t)))}; } } // namespace @@ -805,17 +869,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, case DT_INVALID: // Only occurs for empty tensors. { - tensorflow::TensorHandle* h = nullptr; Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, TensorShape(state.inferred_shape)); - status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &h); - if (!status.ok()) { - PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); - return nullptr; - } - return new TFE_TensorHandle{std::make_unique(h)}; + return new TFE_TensorHandle{ctx->context->CreateLocalHandle( + std::make_unique(std::move(t)))}; } default: diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py index 2e42eb8fbe8..af4b2e9dd60 100644 --- a/tensorflow/python/lib/io/file_io_test.py +++ b/tensorflow/python/lib/io/file_io_test.py @@ -159,6 +159,18 @@ class FileIoTest(test.TestCase): file_io.delete_recursively(dir_path) self.assertFalse(file_io.file_exists(os.path.join(dir_path, "file3.txt"))) + def testGetMatchingFilesWhenParentDirContainsParantheses(self): + dir_path = os.path.join(self._base_dir, "dir_(special)") + file_io.create_dir(dir_path) + files = ["file1.txt", "file(2).txt"] + for name in files: + file_path = os.path.join(dir_path, name) + file_io.FileIO(file_path, mode="w").write("testing") + expected_match = [os.path.join(dir_path, name) for name in files] + glob_pattern = os.path.join(dir_path, "*") + self.assertItemsEqual( + file_io.get_matching_files(glob_pattern), expected_match) + def testCreateRecursiveDir(self): dir_path = os.path.join(self._base_dir, "temp_dir/temp_dir1/temp_dir2") file_io.recursive_create_dir(dir_path) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 989b3a4c338..bbb4f917b12 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -117,16 +117,6 @@ def Assert(condition, data, summarize=None, name=None): If `condition` evaluates to false, print the list of tensors in `data`. `summarize` determines how many entries of the tensors to print. - NOTE: In graph mode, to ensure that Assert executes, one usually attaches - a dependency: - - ```python - # Ensure maximum element of x is smaller or equal to 1 - assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x]) - with tf.control_dependencies([assert_op]): - ... code using x ... - ``` - Args: condition: The condition to evaluate. data: The tensors to print out when condition is false. @@ -141,8 +131,17 @@ def Assert(condition, data, summarize=None, name=None): @end_compatibility Raises: - @compatibility(eager) - `tf.errors.InvalidArgumentError` if `condition` is not true + @compatibility(TF1) + When in TF V1 mode (that is, outside `tf.function`) Assert needs a control + dependency on the output to ensure the assertion executes: + + ```python + # Ensure maximum element of x is smaller or equal to 1 + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x]) + with tf.control_dependencies([assert_op]): + ... code using x ... + ``` + @end_compatibility """ if context.executing_eagerly(): @@ -1724,6 +1723,10 @@ class WhileContext(ControlFlowContext): We move any external control dependencies of the op to the loop pivot, to ensure they get executed. """ + # This is needed to prevent frame mismatch errors where there are Const + # nodes inside tf.function in v1 while_loop and inlining is turned on. + if op.type in ["PartitionedCall", "StatefulPartitionedCall"]: + op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access if not op.inputs: # Remove any external control dependency on this op control_inputs, external_inputs = self._RemoveExternalControlEdges(op) diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index db03bd3e573..c65f67def30 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -26,13 +26,14 @@ from tensorflow.python.eager import function from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import ops from tensorflow.python.framework.func_graph import FuncGraph -from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_v2_func_graphs from tensorflow.python.util import tf_contextlib _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None +_KERAS_LAYER_CONTEXT_FUNCTION = None + CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph @@ -224,8 +225,19 @@ def _is_tpu_strategy(strategy): strategy.__class__.__name__.startswith("TPUStrategy")) +def _register_keras_layer_context_function(func): + global _KERAS_LAYER_CONTEXT_FUNCTION + if _KERAS_LAYER_CONTEXT_FUNCTION is None: + _KERAS_LAYER_CONTEXT_FUNCTION = func + + def _is_building_keras_layer(): - return base_layer_utils.call_context().layer is not None + # TODO(srbs): Remove this function when we no long support session with Keras. + global _KERAS_LAYER_CONTEXT_FUNCTION + if _KERAS_LAYER_CONTEXT_FUNCTION is not None: + return _KERAS_LAYER_CONTEXT_FUNCTION().layer is not None + else: + return False def output_all_intermediates(): diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index f3e5c7cc1bc..85670182a87 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -1269,26 +1270,27 @@ def resize_images(images, The `method` can be one of: - * `ResizeMethod.BILINEAR`: [Bilinear interpolation.]( + * `tf.image.ResizeMethod.BILINEAR`: [Bilinear interpolation.]( https://en.wikipedia.org/wiki/Bilinear_interpolation) - * `ResizeMethod.NEAREST_NEIGHBOR`: [Nearest neighbor interpolation.]( + * `tf.image.ResizeMethod.NEAREST_NEIGHBOR`: [ + Nearest neighbor interpolation.]( https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation) - * `ResizeMethod.BICUBIC`: [Bicubic interpolation.]( + * `tf.image.ResizeMethod.BICUBIC`: [Bicubic interpolation.]( https://en.wikipedia.org/wiki/Bicubic_interpolation) - * `ResizeMethod.AREA`: Area interpolation. + * `tf.image.ResizeMethod.AREA`: Area interpolation. The return value has the same type as `images` if `method` is - `ResizeMethod.NEAREST_NEIGHBOR`. It will also have the same type as `images` - if the size of `images` can be statically determined to be the same as `size`, - because `images` is returned in this case. Otherwise, the return value has - type `float32`. + `tf.image.ResizeMethod.NEAREST_NEIGHBOR`. It will also have the same type + as `images` if the size of `images` can be statically determined to be the + same as `size`, because `images` is returned in this case. Otherwise, the + return value has type `float32`. Args: images: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor of shape `[height, width, channels]`. size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The new size for the images. - method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`. + method: ResizeMethod. Defaults to `tf.image.ResizeMethod.BILINEAR`. align_corners: bool. If True, the centers of the 4 corner pixels of the input and output tensors are aligned, preserving the values at the corner pixels. Defaults to `False`. @@ -4448,6 +4450,83 @@ def non_max_suppression_padded(boxes, sorted_input=False, canonicalized_coordinates=False, tile_size=512): + """Greedily selects a subset of bounding boxes in descending order of score. + + Performs algorithmically equivalent operation to tf.image.non_max_suppression, + with the addition of an optional parameter which zero-pads the output to + be of size `max_output_size`. + The output of this operation is a tuple containing the set of integers + indexing into the input collection of bounding boxes representing the selected + boxes and the number of valid indices in the index set. The bounding box + coordinates corresponding to the selected indices can then be obtained using + the `tf.slice` and `tf.gather` operations. For example: + ```python + selected_indices_padded, num_valid = tf.image.non_max_suppression_padded( + boxes, scores, max_output_size, iou_threshold, + score_threshold, pad_to_max_output_size=True) + selected_indices = tf.slice( + selected_indices_padded, tf.constant([0]), num_valid) + selected_boxes = tf.gather(boxes, selected_indices) + + Args: + boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4]. + Dimensions except the last two are batch dimensions. + scores: a tensor of rank 1 or higher with a shape of [..., num_boxes]. + max_output_size: a scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IoU (intersection over union). + score_threshold: a float representing the threshold for box scores. Boxes + with a score that is not larger than this threshold will be suppressed. + pad_to_max_output_size: whether to pad the output idx to max_output_size. + Must be set to True when the input is a batch of images. + name: name of operation. + sorted_input: a boolean indicating whether the input boxes and scores + are sorted in descending order by the score. + canonicalized_coordinates: if box coordinates are given as + `[y_min, x_min, y_max, x_max]`, settign to True eliminate redundant + computation to canonicalize box coordinates. + tile_size: an integer representing the number of boxes in a tile, i.e., + the maximum number of boxes per image that can be used to suppress other + boxes in parallel; larger tile_size means larger parallelism and + potentially more redundant work. + Returns: + idx: a tensor with a shape of [..., num_boxes] representing the + indices selected by non-max suppression. The leadign dimensions + are the batch dimensions of the input boxes. All numbers are are within + [0, num_boxes). For each image (i.e., idx[i]), only the first num_valid[i] + indices (i.e., idx[i][:num_valid[i]]) are valid. + num_valid: a tensor of rank 0 or higher with a shape of [...] + representing the number of valid indices in idx. Its dimensions are the + batch dimensions of the input boxes. + Raises: + ValueError: When set pad_to_max_output_size to False for batched input. + """ + # if no new arguments are used and no later than 2020/4/20, use the old + # version to give us time to fix TFLite conversion + if (not sorted_input) and \ + (not canonicalized_coordinates) and \ + tile_size == 512 and compat.forward_compatible(2020, 4, 20): + return non_max_suppression_padded_v1( + boxes, scores, max_output_size, iou_threshold, score_threshold, + pad_to_max_output_size, name) + else: + return non_max_suppression_padded_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold, + pad_to_max_output_size, name, sorted_input, canonicalized_coordinates, + tile_size) + + +def non_max_suppression_padded_v2(boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=float('-inf'), + pad_to_max_output_size=False, + name=None, + sorted_input=False, + canonicalized_coordinates=False, + tile_size=512): """Non-maximum suppression. Prunes away boxes that have high intersection-over-union (IOU) overlap @@ -4465,9 +4544,9 @@ def non_max_suppression_padded(boxes, system; thus translating or reflections of the coordinate system result in the same boxes being selected by the algorithm. - Similar to tf.image.non_max_suppression, batched_non_max_suppression + Similar to tf.image.non_max_suppression, non_max_suppression_padded implements hard NMS but can operate on a batch of images and improves - performance by titling the bounding boxes. Batched_non_max_suppression should + performance by titling the bounding boxes. Non_max_suppression_padded should be preferred over tf.image_non_max_suppression when running on devices with abundant parallelsim for higher computation speed. For soft NMS, refer to tf.image.non_max_suppression_with_scores. @@ -4523,7 +4602,7 @@ def non_max_suppression_padded(boxes, iou_threshold: a float representing the threshold for deciding whether boxes overlap too much with respect to IoU (intersection over union). score_threshold: a float representing the threshold for box scores. Boxes - with a score that is lower than this threshold will be suppressed. + with a score that is not larger than this threshold will be suppressed. pad_to_max_output_size: whether to pad the output idx to max_output_size. Must be set to True when the input is a batch of images. name: name of operation. @@ -4579,19 +4658,24 @@ def non_max_suppression_padded(boxes, [batch_size, -1, 4]) return sorted_scores, sorted_boxes, sorted_scores_indices - with ops.name_scope(name, 'batched_non_max_suppression'): - if boxes.get_shape().ndims > 2 and not pad_to_max_output_size: - raise ValueError("'pad_to_max_output_size' (value {}) must be " - "True for batched input".format(pad_to_max_output_size)) + with ops.name_scope(name, 'non_max_suppression_padded'): + if not pad_to_max_output_size: + # pad_to_max_output_size may be set to False only when the shape of boxes` + # is [num_boxes, 4], i.e., a single image. We make best effort to detect + # violations at compile time. If `boxes` does not have a static rank, + # the check allows computation to proceed. + if boxes.get_shape().rank is not None and boxes.get_shape().rank > 2: + raise ValueError("'pad_to_max_output_size' (value {}) must be True for " + "batched input".format(pad_to_max_output_size)) - batch_dims = boxes.get_shape().as_list()[:-2] + batch_dims = array_ops.shape(boxes)[:-2] num_boxes = array_ops.shape(boxes)[-2] boxes = array_ops.reshape(boxes, [-1, num_boxes, 4]) scores = array_ops.reshape(scores, [-1, num_boxes]) batch_size = array_ops.shape(boxes)[0] if score_threshold != float('-inf'): with ops.name_scope('filter_by_score'): - score_mask = math_ops.cast(scores >= score_threshold, scores.dtype) + score_mask = math_ops.cast(scores > score_threshold, scores.dtype) scores *= score_mask box_mask = array_ops.expand_dims( math_ops.cast(score_mask, boxes.dtype), 2) @@ -4601,10 +4685,12 @@ def non_max_suppression_padded(boxes, with ops.name_scope('canonicalize_coordinates'): y_1, x_1, y_2, x_2 = array_ops.split( value=boxes, num_or_size_splits=4, axis=2) - y_1_is_min = math_ops.less(y_1[0, 0, 0], y_2[0, 0, 0]) + y_1_is_min = math_ops.reduce_all( + math_ops.less_equal(y_1[0, 0, 0], y_2[0, 0, 0])) y_min, y_max = control_flow_ops.cond( y_1_is_min, lambda: (y_1, y_2), lambda: (y_2, y_1)) - x_1_is_min = math_ops.less(x_1[0, 0, 0], x_2[0, 0, 0]) + x_1_is_min = math_ops.reduce_all( + math_ops.less_equal(x_1[0, 0, 0], x_2[0, 0, 0])) x_min, x_max = control_flow_ops.cond( x_1_is_min, lambda: (x_1, x_2), lambda: (x_2, x_1)) boxes = array_ops.concat([y_min, x_min, y_max, x_max], axis=2) @@ -4614,18 +4700,20 @@ def non_max_suppression_padded(boxes, pad = math_ops.cast( math_ops.ceil( - math_ops.cast(num_boxes, dtypes.float32) / tile_size), + math_ops.cast(math_ops.maximum(num_boxes, max_output_size), + dtypes.float32) / tile_size + ), dtypes.int32) * tile_size - num_boxes boxes = array_ops.pad( math_ops.cast(boxes, dtypes.float32), [[0, 0], [0, pad], [0, 0]]) scores = array_ops.pad( math_ops.cast(scores, dtypes.float32), [[0, 0], [0, pad]]) num_boxes_after_padding = num_boxes + pad - + num_iterations = num_boxes_after_padding // tile_size def _loop_cond(unused_boxes, unused_threshold, output_size, idx): return math_ops.logical_and( math_ops.reduce_min(output_size) < max_output_size, - idx < num_boxes_after_padding // tile_size) + idx < num_iterations) def suppression_loop_body(boxes, iou_threshold, output_size, idx): return _suppression_loop_body( @@ -4645,6 +4733,7 @@ def non_max_suppression_padded(boxes, math_ops.range(num_boxes_after_padding, 0, -1), 0), max_output_size)[0], dtypes.int32) idx = math_ops.minimum(idx, num_boxes - 1) + if not sorted_input: index_offsets = math_ops.range(batch_size) * num_boxes gather_idx = array_ops.reshape( @@ -4653,15 +4742,77 @@ def non_max_suppression_padded(boxes, array_ops.gather(array_ops.reshape(sorted_indices, [-1]), gather_idx), [batch_size, -1]) + invalid_index = array_ops.fill([batch_size, max_output_size], 0) + idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0) + num_valid_expanded = array_ops.expand_dims(num_valid, 1) + idx = array_ops.where(idx_index < num_valid_expanded, + idx, invalid_index) num_valid = array_ops.reshape(num_valid, batch_dims) if not pad_to_max_output_size: idx = idx[0, :num_valid] - batch_dims.append(-1) + return idx, num_valid + last_dim = constant_op.constant(-1, shape=[1]) + batch_dims = array_ops.concat([batch_dims, last_dim], 0) idx = array_ops.reshape(idx, batch_dims) return idx, num_valid +def non_max_suppression_padded_v1(boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=float('-inf'), + pad_to_max_output_size=False, + name=None): + """Greedily selects a subset of bounding boxes in descending order of score. + + Performs algorithmically equivalent operation to tf.image.non_max_suppression, + with the addition of an optional parameter which zero-pads the output to + be of size `max_output_size`. + The output of this operation is a tuple containing the set of integers + indexing into the input collection of bounding boxes representing the selected + boxes and the number of valid indices in the index set. The bounding box + coordinates corresponding to the selected indices can then be obtained using + the `tf.slice` and `tf.gather` operations. For example: + ```python + selected_indices_padded, num_valid = tf.image.non_max_suppression_padded( + boxes, scores, max_output_size, iou_threshold, + score_threshold, pad_to_max_output_size=True) + selected_indices = tf.slice( + selected_indices_padded, tf.constant([0]), num_valid) + selected_boxes = tf.gather(boxes, selected_indices) + ``` + + Args: + boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`. + scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). + max_output_size: A scalar integer `Tensor` representing the maximum number + of boxes to be selected by non-max suppression. + iou_threshold: A float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + score_threshold: A float representing the threshold for deciding when to + remove boxes based on score. + pad_to_max_output_size: bool. If True, size of `selected_indices` output is + padded to `max_output_size`. + name: A name for the operation (optional). + + Returns: + selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the + selected indices from the boxes tensor, where `M <= max_output_size`. + valid_outputs: A scalar integer `Tensor` denoting how many elements in + `selected_indices` are valid. Valid elements occur first, then padding. + """ + with ops.name_scope(name, 'non_max_suppression_padded'): + iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') + score_threshold = ops.convert_to_tensor( + score_threshold, name='score_threshold') + return gen_image_ops.non_max_suppression_v4(boxes, scores, max_output_size, + iou_threshold, score_threshold, + pad_to_max_output_size) + + @tf_export('image.draw_bounding_boxes', v1=[]) def draw_bounding_boxes_v2(images, boxes, colors, name=None): """Draw bounding boxes on a batch of images. diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 8d64672d41f..2ed077a862c 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -4615,9 +4615,8 @@ class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase): self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True) self.assertEqual(selected_indices.shape.is_fully_defined(), False) with self.cached_session(): - invalid_index = len(boxes_np) - 1 self.assertAllClose(selected_indices_padded.eval(), - [3, 0, 5, invalid_index, invalid_index]) + [3, 0, 5, 0, 0]) self.assertEqual(num_valid_padded.eval(), 3) self.assertAllClose(selected_indices.eval(), [3, 0, 5]) self.assertEqual(num_valid.eval(), 3) diff --git a/tensorflow/python/ops/init_ops_v2.py b/tensorflow/python/ops/init_ops_v2.py index 4999c4d8aac..3c110fe9cf9 100644 --- a/tensorflow/python/ops/init_ops_v2.py +++ b/tensorflow/python/ops/init_ops_v2.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import linalg_ops_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.ops.init_ops import _compute_fans from tensorflow.python.util.tf_export import tf_export @@ -126,6 +127,8 @@ class Zeros(Initializer): ValuesError: If the dtype is not numeric or boolean. """ dtype = dtypes.as_dtype(dtype) + if not dtype.is_numpy_compatible or dtype == dtypes.string: + raise ValueError("Expected numeric or boolean dtype, got %s." % dtype) return array_ops.zeros(shape, dtype) @@ -991,33 +994,6 @@ def he_uniform(seed=None): # Utility functions. -def _compute_fans(shape): - """Computes the number of input and output units for a weight shape. - - Args: - shape: Integer shape tuple or TF tensor shape. - - Returns: - A tuple of scalars (fan_in, fan_out). - """ - if len(shape) < 1: # Just to avoid errors for constants. - fan_in = fan_out = 1 - elif len(shape) == 1: - fan_in = fan_out = shape[0] - elif len(shape) == 2: - fan_in = shape[0] - fan_out = shape[1] - else: - # Assuming convolution kernels (2D, 3D, or more). - # kernel shape: (..., input_depth, depth) - receptive_field_size = 1. - for dim in shape[:-2]: - receptive_field_size *= dim - fan_in = shape[-2] * receptive_field_size - fan_out = shape[-1] * receptive_field_size - return fan_in, fan_out - - def _assert_float_dtype(dtype): """Validate and return floating point type based on `dtype`. diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py index ae112bc3ea0..142d48c5331 100644 --- a/tensorflow/python/ops/linalg/linear_operator_householder.py +++ b/tensorflow/python/ops/linalg/linear_operator_householder.py @@ -64,6 +64,7 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator): x = ... Shape [2, 4] Tensor operator.matmul(x) ==> Shape [2, 4] Tensor + ``` #### Shape compatibility diff --git a/tensorflow/python/ops/linalg/linear_operator_permutation.py b/tensorflow/python/ops/linalg/linear_operator_permutation.py index b705f0a077d..3a44cd5ef1b 100644 --- a/tensorflow/python/ops/linalg/linear_operator_permutation.py +++ b/tensorflow/python/ops/linalg/linear_operator_permutation.py @@ -75,6 +75,7 @@ class LinearOperatorPermutation(linear_operator.LinearOperator): x = ... Shape [3, 4] Tensor operator.matmul(x) ==> Shape [3, 4] Tensor + ``` #### Shape compatibility diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index ee01ff7cf68..3e7c116ec97 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -186,6 +186,8 @@ def _PopBackGrad(op, dlist, delement): element_dtype=delement.dtype, element_shape=gen_list_ops.tensor_list_element_shape( op.outputs[0], shape_type=dtypes.int32)) + if delement is None: + delement = array_ops.zeros_like(op.outputs[1]) return gen_list_ops.tensor_list_push_back(dlist, delement), None diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index bf725b34e0b..062b571ff4e 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -100,9 +100,118 @@ from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export # Aliases for some automatically-generated names. -linspace = gen_math_ops.lin_space nextafter = gen_math_ops.next_after + +@tf_export("linspace", v1=["lin_space", "linspace"]) +@deprecation.deprecated_endpoints("lin_space") +def linspace_nd(start, stop, num, name=None, axis=0): + r"""Generates evenly-spaced values in an interval along a given axis. + + A sequence of `num` evenly-spaced values are generated beginning at `start` + along a given `axis`. + If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, + so that the last one is exactly `stop`. If `num <= 0`, `ValueError` is raised. + + Matches + [np.linspace](https://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html)'s + behaviour + except when `num == 0`. + + For example: + + ``` + tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] + ``` + + `Start` and `stop` can be tensors of arbitrary size: + + >>> tf.linspace([0., 5.], [10., 40.], 5, axis=0) + + + `Axis` is where the values will be generated (the dimension in the + returned tensor which corresponds to the axis will be equal to `num`) + + >>> tf.linspace([0., 5.], [10., 40.], 5, axis=-1) + + + + + Args: + start: A `Tensor`. Must be one of the following types: `bfloat16`, + `float32`, `float64`. N-D tensor. First entry in the range. + stop: A `Tensor`. Must have the same type and shape as `start`. N-D tensor. + Last entry in the range. + num: A `Tensor`. Must be one of the following types: `int32`, `int64`. 0-D + tensor. Number of values to generate. + name: A name for the operation (optional). + axis: Axis along which the operation is performed (used only when N-D + tensors are provided). + + Returns: + A `Tensor`. Has the same type as `start`. + """ + + with ops.name_scope(name, "linspace", [start, stop]): + start = ops.convert_to_tensor(start, name="start") + # stop must be convertible to the same dtype as start + stop = ops.convert_to_tensor(stop, name="stop", dtype=start.dtype) + num_int = array_ops.convert_to_int_tensor(num, name="num") + num = cast(num_int, dtype=start.dtype) + + expanded_start = array_ops.expand_dims(start, axis=axis) + expanded_stop = array_ops.expand_dims(stop, axis=axis) + + shape = array_ops.shape(expanded_start) + ndims = array_ops.shape(shape)[0] + + axis = array_ops.where_v2(axis >= 0, axis, ndims + axis) + + # to avoid having negative values in the range or zero division + # The result is sliced in the end so a correct result is returned for + # num == 1. + n_steps = gen_math_ops.maximum(num - 1., 1.) + delta = (expanded_stop - expanded_start) / n_steps + # If num < 0, we will throw exception in the range + # otherwise use the same div for delta + range_end = array_ops.where_v2(num_int > 0, n_steps, -1) + num_range = range(1., range_end, dtype=start.dtype) + shape_range = range(ndims) + ones_like_shape_range = array_ops.ones_like(shape_range) + axis_tiled = ones_like_shape_range * axis + # the purpose is to avoid having negative values when repeating + num_fill = gen_math_ops.maximum(num_int - 2, 0) + num_tiled = array_ops.ones_like(shape_range) * num_fill + ones = array_ops.ones_like(num_tiled) + mask = gen_math_ops.equal(axis_tiled, shape_range) + # reshape_target is [1. 1. 1. ... 1. num 1. 1. ... 1.], where the index + # of num is equal to axis + reshape_target = array_ops.where_v2(mask, num_fill, shape) + delta_expanded = array_ops.reshape(delta, shape) + delta_repeated = array_ops.broadcast_to(delta_expanded, reshape_target) + start_repeated = array_ops.broadcast_to(expanded_start, reshape_target) + + expanded_shape = array_ops.where_v2(mask, num_fill, ones) + range_indices = array_ops.reshape(num_range, expanded_shape) + tiled_range_indices = array_ops.tile(range_indices, shape) + res = start_repeated + delta_repeated * tiled_range_indices + all_tensors = (expanded_start, res, expanded_stop) + concatenated = array_ops.concat(all_tensors, axis=axis) + begin = array_ops.zeros_like(shape) + num_slice = ones_like_shape_range * num_int + size = array_ops.where_v2(mask, num_slice, shape) + return array_ops.slice(concatenated, begin, size) + + +linspace = linspace_nd + arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max) # pylint: disable=used-before-assignment arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min) # pylint: disable=used-before-assignment tf_export(v1=["arg_max"])(arg_max) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 7e159747ff9..8df06a8e861 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -780,7 +780,7 @@ def convolution( name=None, data_format=None, filters=None, - dilations=None): + dilations=None): # pylint: disable=g-doc-args """Computes sums of N-D convolutions (actually cross-correlation). This also supports either output striding via the optional `strides` parameter @@ -865,8 +865,6 @@ def convolution( starts with "NC"). For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW". - filters: Alias of filter. - dilations: Alias of dilation_rate. Returns: A `Tensor` with the same type as `input` of shape diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index c7157f6ac1d..4d2161a93b8 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -1093,6 +1093,7 @@ py_test( srcs = ["ragged_map_fn_op_test.py"], python_version = "PY3", srcs_version = "PY2AND3", + tags = ["no_rocm"], deps = [ ":ragged", # fixdeps: keep ":ragged_factory_ops", diff --git a/tensorflow/python/ops/ragged/ragged_from_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_from_tensor_op_test.py index d54bc904874..110caa28b59 100644 --- a/tensorflow/python/ops/ragged/ragged_from_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_from_tensor_op_test.py @@ -20,6 +20,9 @@ from __future__ import print_function from absl.testing import parameterized +import numpy as np + +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -29,8 +32,8 @@ from tensorflow.python.platform import googletest @test_util.run_all_in_graph_and_eager_modes -class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase, - parameterized.TestCase): +class RaggedTensorFromTensorOpTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def testDocStringExamples(self): # The examples from RaggedTensor.from_tensor.__doc__. @@ -366,6 +369,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase, if expected_shape is not None: self.assertEqual(rt.shape.as_list(), expected_shape) self.assertAllEqual(rt, expected) + self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits( + rt.flat_values, rt.nested_row_splits, validate=True)) def testHighDimensions(self): # Use distinct prime numbers for all dimension shapes in this test, so @@ -380,6 +385,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase, dt.shape.is_compatible_with(rt.shape), '%s is incompatible with %s' % (dt.shape, rt.shape)) self.assertAllEqual(rt, self.evaluate(dt).tolist()) + self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits( + rt.flat_values, rt.nested_row_splits, validate=True)) @parameterized.parameters( # With no padding or lengths @@ -399,6 +406,10 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase, 'dt_shape': [0, 2, 3], 'expected': [] }, + { + 'dt_shape': [1, 0, 0], + 'expected': [[]] + }, { 'dt_shape': [2, 0, 3], 'expected': [[], []] @@ -485,11 +496,74 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase, ) def testEmpty(self, dt_shape, expected, lengths=None, padding=None): dt = array_ops.zeros(dt_shape) - rt = RaggedTensor.from_tensor(dt, lengths, padding) - self.assertEqual(type(rt), RaggedTensor) - self.assertEqual(rt.ragged_rank, 1) - self.assertTrue(dt.shape.is_compatible_with(rt.shape)) - self.assertAllEqual(rt, expected) + for ragged_rank in range(1, len(dt_shape) - 1): + rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank) + self.assertEqual(type(rt), RaggedTensor) + self.assertEqual(rt.ragged_rank, ragged_rank) + self.assertTrue(dt.shape.is_compatible_with(rt.shape)) + self.assertAllEqual(rt, expected) + self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits( + rt.flat_values, rt.nested_row_splits, validate=True)) + + @parameterized.named_parameters([ + { + 'testcase_name': '2D_UnknownRank', + 'tensor': [[1, 2], [3, 4]], + 'tensor_shape': None, + }, + { + 'testcase_name': '2D_Shape_None_None', + 'tensor': [[1, 2], [3, 4]], + 'tensor_shape': [None, None], + }, + { + 'testcase_name': '2D_Shape_2_None', + 'tensor': [[1, 2], [3, 4]], + 'tensor_shape': [2, None], + }, + { + 'testcase_name': '2D_Shape_None_2', + 'tensor': [[1, 2], [3, 4]], + 'tensor_shape': [None, 2], + }, + { + 'testcase_name': '4D_UnknownRank', + 'tensor': np.ones([4, 3, 2, 1]), + 'tensor_shape': None, + }, + { + 'testcase_name': '4D_Shape_None_None_None_None', + 'tensor': np.ones([4, 3, 2, 1]), + 'tensor_shape': [None, None, None, None], + }, + { + 'tensor': np.ones([4, 3, 2, 1]), + 'tensor_shape': [4, None, None, 1], + 'testcase_name': '4D_Shape_4_None_None_1', + }, + ]) + def testPartialShapes(self, tensor, tensor_shape, shape=None, + expected=None): + if expected is None: + expected = tensor + + if context.executing_eagerly(): + return # static shapes are always fully defined in eager mode. + + dt = constant_op.constant(tensor) + for ragged_rank in range(1, len(dt.shape) - 1): + dt_placeholder = array_ops.placeholder_with_default(tensor, tensor_shape) + rt = RaggedTensor.from_tensor(dt_placeholder, ragged_rank=ragged_rank) + self.assertIsInstance(rt, RaggedTensor) + self.assertEqual(rt.ragged_rank, ragged_rank) + self.assertTrue( + dt.shape.is_compatible_with(rt.shape), + '%s is incompatible with %s' % (dt.shape, rt.shape)) + if shape is not None: + self.assertEqual(rt.shape.as_list(), shape) + self.assertAllEqual(rt, expected.tolist()) + self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits( + rt.flat_values, rt.nested_row_splits, validate=True)) @parameterized.parameters( { diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 6d365210308..afb631ed0f2 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -30,8 +30,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_like from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -54,7 +56,7 @@ _convert_row_partition = RowPartition._convert_row_partition @tf_export("RaggedTensor") -class RaggedTensor(composite_tensor.CompositeTensor): +class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike): """Represents a ragged tensor. A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are @@ -1228,13 +1230,18 @@ class RaggedTensor(composite_tensor.CompositeTensor): splits_shape = array_ops.shape(self.row_splits, out_type=out_type) flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type) - ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [ + ragged_dimensions = [splits_shape[0] - 1] + [ math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) for splits in nested_splits - ]) + ] inner_dimensions = flat_values_shape[1:] - bbox = array_ops.concat([ragged_dimensions, inner_dimensions], axis=0) + if out_type != self._row_partition.dtype: + ragged_dimensions = [ + math_ops.cast(d, out_type) for d in ragged_dimensions + ] + bbox = array_ops.concat( + [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0) return bbox if axis is None else array_ops.gather(bbox, axis) #============================================================================= @@ -1368,6 +1375,59 @@ class RaggedTensor(composite_tensor.CompositeTensor): "inner_axis (%d)" % (outer_axis, inner_axis)) return _merge_dims(self, outer_axis, inner_axis) + def _set_shape(self, shape): + """Updates the static shape of `self` to be `shape`. + + * If a dimension of `shape` has known rank, and is encoded via + partitioning, then this will update the corresponding partition to + define `_uniform_row_length` and `nrows`. + * If a dimension of `shape` has a known rank, and is encoded as one + of the `flat_values` dimensions, then `flat_values.set_shape()` will + be used to update its shape. + + Warning: Using this method to assert an incorrect shape for a RaggedTensor + (i.e., one that's not consistent with its actual shape) can cause + segmentation faults and very difficult-to-diagnose behavior. Only use this + method if you are certain that the shape is correct. + + Args: + shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`. + """ + # TODO(edloper): Refactor this to not directly access private members + # of RowPartition. + # pylint: disable=protected-access + + shape = tensor_shape.as_shape(shape) + if shape.rank is None: + return # Nothing to do. + + shape = shape.as_list() + + # Outermost dimension + if shape[0] is not None: + self._row_partition._row_splits.set_shape(shape[0] + 1) + + # Partitioned dimensions + dtype = self._row_partition.dtype + for i, partition in enumerate(self._nested_row_partitions): + size = shape[i + 1] + if size is not None: + if partition._uniform_row_length is not None: + old_row_length = tensor_util.constant_value( + partition._uniform_row_length) + if old_row_length is not None: + if size == old_row_length: + continue # already have shape info for this axis. + else: + raise ValueError("Inconsistent size for axis %s: %s vs %s" % + ((i + 1), old_row_length, size)) + partition._uniform_row_length = ops.convert_to_tensor(size, dtype) + if partition._nrows is None: + partition._nrows = array_ops.size(partition._row_splits) - 1 + + # Inner dimensions + flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:]) + self.flat_values.set_shape(flat_shape) #============================================================================= # Tensor Type Conversions @@ -1481,15 +1541,15 @@ class RaggedTensor(composite_tensor.CompositeTensor): if ragged_rank > 1: if tensor.shape.is_fully_defined(): input_shape = tensor.shape.as_list() - new_shape = [-1] + input_shape[ragged_rank:] # The total number of elements in each dimension. E.g., if # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. dim_size = np.cumprod(input_shape) + new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:] else: - neg_one = constant_op.constant([-1], row_splits_dtype) - new_shape = array_ops.concat([neg_one, input_shape[ragged_rank:]], - axis=0) dim_size = math_ops.cumprod(input_shape) + new_shape = array_ops.concat([[dim_size[ragged_rank - 1]], + input_shape[ragged_rank:]], + axis=0) flattened = array_ops.reshape(tensor, new_shape) result = cls.from_tensor( flattened, lengths, padding, row_splits_dtype=row_splits_dtype) @@ -1563,7 +1623,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): # If neither padding nor lengths were specified, then create a splits # vector that contains no default values, and reshape the input tensor # to form the values for the RaggedTensor. - values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0) + values_shape = array_ops.concat([[input_shape[0] * input_shape[1]], + input_shape[2:]], axis=0) values = array_ops.reshape(tensor, values_shape) const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value @@ -1620,13 +1681,30 @@ class RaggedTensor(composite_tensor.CompositeTensor): default_value = array_ops.zeros((), self.dtype) shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype) - return gen_ragged_conversion_ops.ragged_tensor_to_tensor( + tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor( shape=shape_tensor, values=self.flat_values, default_value=default_value, row_partition_types=row_partition_types, row_partition_tensors=row_partition_tensors) + ragged_shape = self.shape + + if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor): + # Merged self.shape and shape, favoring the second one as it takes + # into account potential padding added to the output. + shape = tensor_shape.as_shape(shape) + if shape.rank is None: + output_shape = ragged_shape + else: + # At this point we can assume that hshape.rank == ragged_shape.rank + # because otherwise it would have failed earlier. + output_shape = [s1 if s1 is not None else s2 for (s1, s2) + in zip(shape.as_list(), ragged_shape.as_list())] + tensor.set_shape(output_shape) + + return tensor + @classmethod def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64): """Converts a 2D `tf.SparseTensor` to a `RaggedTensor`. @@ -2085,11 +2163,6 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): else: return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value) - @property - def dtype(self): - """The `tf.dtypes.DType` specified by this type for the RaggedTensor.""" - return self._dtype - def _serialize(self): return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype) @@ -2145,6 +2218,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): return [tensor_spec.TensorSpec(None, dtypes.variant)] def _to_tensor_list(self, value): + # TODO(edloper): Update gen_ragged_conversion_ops that convert to and + # from variant to include all of the row-partitioning tensors. ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0 if ragged_rank != self._ragged_rank: raise ValueError("Ragged rank of value (%d) does not match ragged " @@ -2183,9 +2258,7 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): outer_dim = tensor_shape.dimension_value(self._shape[0]) if outer_dim is not None: result.row_splits.set_shape([outer_dim + 1]) - result.flat_values.set_shape( - tensor_shape.TensorShape([None]).concatenate( - self._shape[1 + self._ragged_rank:])) + result._set_shape(self._shape) # pylint: disable=protected-access else: result.set_shape(self._shape) return result diff --git a/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py b/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py index 3f0ec1d12f3..faf44d36b11 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py @@ -18,45 +18,119 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest @test_util.run_all_in_graph_and_eager_modes -class RaggedTensorBoundingShapeOp(test_util.TensorFlowTestCase): +class RaggedTensorBoundingShapeOp(test_util.TensorFlowTestCase, + parameterized.TestCase): - def testDocStringExample(self): - # This is the example from ragged.bounding_shape.__doc__. - rt = ragged_factory_ops.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], - [10]]) - self.assertAllEqual(rt.bounding_shape(), [5, 4]) + @parameterized.named_parameters([ + # rank = 2 + dict(testcase_name='docstring_example', + rt=[[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]], + expected=[5, 4]), + dict(testcase_name='shape_5_3', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + expected=[5, 3]), + dict(testcase_name='shape_1_7', + rt=[['a', 'b', 'c', 'd', 'e', 'f', 'g']], + expected=[1, 7]), + dict(testcase_name='shape_3_7', + rt=[[], ['a', 'b', 'c', 'd', 'e', 'f', 'g'], []], + expected=[3, 7]), + dict(testcase_name='shape_5_3_row_splits_int32', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + rt_row_splits_dtype=dtypes.int32, + expected=[5, 3]), + dict(testcase_name='shape_0_0', + rt=[], + rt_ragged_rank=1, + expected=[0, 0]), + dict(testcase_name='shape_3_0', + rt=[[], [], []], + expected=[3, 0]), + # rank = 3 + dict(testcase_name='shape_5_3_2', + rt=[[[0, 1], [2]], [[3, 4], [], [5, 6]], [[7]], [], [[8, 9]]], + expected=[5, 3, 2]), + dict(testcase_name='shape_1_7_2', + rt=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]], + expected=[1, 7, 2]), + dict(testcase_name='shape_3_7_4', + rt=[[], [[0, 1], [2], [], [3], [4], [5, 6, 7, 8], [9]], []], + expected=[3, 7, 4]), + dict(testcase_name='shape_1_7_2_ragged_rank_1', + rt=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]], + rt_ragged_rank=1, + expected=[1, 7, 2]), + # axis != None + dict(testcase_name='shape_5_3_axis_0', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + axis=0, + expected=5), + dict(testcase_name='shape_5_3_axis_1', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + axis=1, + expected=3), + dict(testcase_name='shape_5_3_axis_1_0', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + axis=[1, 0], + expected=[3, 5]), + # out_type != None + dict(testcase_name='shape_5_3_row_splits_int64_out_type_int64', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + rt_row_splits_dtype=dtypes.int64, + out_type=dtypes.int64, + expected=[5, 3]), + dict(testcase_name='shape_5_3_row_splits_int32_out_type_int32', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + rt_row_splits_dtype=dtypes.int32, + out_type=dtypes.int32, + expected=[5, 3]), + dict(testcase_name='shape_5_3_row_splits_int64_out_type_int32', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + rt_row_splits_dtype=dtypes.int64, + out_type=dtypes.int32, + expected=[5, 3]), + dict(testcase_name='shape_5_3_row_splits_int32_out_type_int64', + rt=[['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']], + rt_row_splits_dtype=dtypes.int32, + out_type=dtypes.int64, + expected=[5, 3]), + ]) # pyformat: disable + def testBoundingShape(self, + rt, + expected, + axis=None, + out_type=None, + rt_row_splits_dtype=dtypes.int64, + rt_ragged_rank=None): + rt = ragged_factory_ops.constant( + rt, ragged_rank=rt_ragged_rank, row_splits_dtype=rt_row_splits_dtype) + bounding_shape = rt.bounding_shape(axis=axis, out_type=out_type) + self.assertAllEqual(bounding_shape, expected) + if out_type is not None: + self.assertEqual(bounding_shape.dtype, out_type) + else: + self.assertEqual(bounding_shape.dtype, rt_row_splits_dtype) - def test2DRaggedTensorWithOneRaggedDimension(self): - values = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7]) - rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7]) - rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7]) - self.assertAllEqual(rt1.bounding_shape(), [5, 3]) - self.assertAllEqual(rt2.bounding_shape(), [1, 7]) - self.assertAllEqual(rt3.bounding_shape(), [3, 7]) - - def test3DRaggedTensorWithOneRaggedDimension(self): - values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]] - rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7]) - rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7]) - rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7]) - self.assertAllEqual(rt1.bounding_shape(), [5, 3, 2]) - self.assertAllEqual(rt2.bounding_shape(), [1, 7, 2]) - self.assertAllEqual(rt3.bounding_shape(), [3, 7, 2]) - - def testExplicitAxisOptimizations(self): - rt = ragged_tensor.RaggedTensor.from_row_splits(b'a b c d e f g'.split(), - [0, 2, 5, 6, 6, 7]) - self.assertAllEqual(rt.bounding_shape(0), 5) - self.assertAllEqual(rt.bounding_shape(1), 3) - self.assertAllEqual(rt.bounding_shape([1, 0]), [3, 5]) + # If we're testing a configuration that uses `axis`, then make sure + # that it also works if `axis` is a tensor. + if axis is not None: + bounding_shape = rt.bounding_shape( + axis=constant_op.constant(axis), out_type=out_type) + self.assertAllEqual(bounding_shape, expected) + if out_type is not None: + self.assertEqual(bounding_shape.dtype, out_type) + else: + self.assertEqual(bounding_shape.dtype, rt_row_splits_dtype) if __name__ == '__main__': diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index 20c21bd5947..5b6521b5aa5 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -41,6 +41,7 @@ from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.platform import googletest +from tensorflow.python.util import nest def int32array(values): @@ -1483,6 +1484,73 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'): rt.numpy() + @parameterized.parameters([ + ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), + ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), + ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), + ([[[1, 2, 3]]], 1, [1, 1, None]), + ([[[1, 2, 3]]], 1, [1, 1, 3]), + ]) + def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape): + rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank) + rt1._set_shape(shape) + rt1.shape.assert_is_compatible_with(shape) + if shape is not None: + self.assertIsNot(rt1.shape.rank, None) + for a, b in zip(rt1.shape, shape): + if b is not None: + self.assertEqual(a, b) + + @parameterized.parameters([ + ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), + ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), + ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), + ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), + ([[[1, 2, 3]]], 1, [1, 1, None]), + ([[[1, 2, 3]]], 1, [1, 1, 3]), + ]) + def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape): + rt2 = nest.map_structure( + lambda x: array_ops.placeholder_with_default(x, None), + ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank), + expand_composites=True) + rt2._set_shape(shape) + rt2.shape.assert_is_compatible_with(shape) + if shape is not None: + self.assertIsNot(rt2.shape.rank, None) + for a, b in zip(rt2.shape, shape): + if b is not None: + self.assertEqual(a, b) + + def testRaggedTensorSetShapeUniformRowLength(self): + rt = [[[1], [2], [3]], [[4], [5], [6]]] + + rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1) + rt1._set_shape([2, 3, 1]) + + rt2 = nest.map_structure( + lambda x: array_ops.placeholder_with_default(x, None), + rt1, expand_composites=True) + rt2._set_shape([2, 3, 1]) + + def testRaggedTensorSetShapeInconsistentShapeError(self): + rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]], + ragged_rank=1) + self.assertEqual(rt.shape.as_list(), [2, 3, 1]) + with self.assertRaises(ValueError): + rt._set_shape([None, None, 5]) + with self.assertRaisesRegex(ValueError, 'Inconsistent size'): + rt._set_shape([None, 5, None]) + with self.assertRaises(ValueError): + rt._set_shape([5, None, None]) + @test_util.run_all_in_graph_and_eager_modes class RaggedTensorSpecTest(test_util.TensorFlowTestCase, @@ -1665,6 +1733,17 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase, [t[0] for t in tensor_list]) self.assertAllEqual(rt[0], first_row) + def testToFromBatchedTensorListPreservesUniformRowLengths(self): + rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]), + ragged_rank=2) + rt_spec = rt._type_spec + tensor_list = rt_spec._to_batched_tensor_list(rt) + rt_reconstructed = rt_spec._from_tensor_list(tensor_list) + self.assertAllEqual(rt, rt_reconstructed) + self.assertTrue(rt.shape.is_fully_defined()) + self.assertTrue(rt_reconstructed.shape.is_fully_defined()) + self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list()) + @parameterized.parameters([ (RaggedTensorSpec([2, None], dtypes.float32, 1), 32, RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), diff --git a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py index fc2047de954..83b36394cc4 100644 --- a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py @@ -475,6 +475,22 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase, actual = input_data.to_tensor(shape=[3, 4]) self.assertAllEqual(actual, [[0, 1, 2, 0], [0, 0, 0, 0], [3, 0, 0, 0]]) + @parameterized.parameters( + ([2, 3, 4], None, [2, 3, 4]), + ([2, 3, 4], [None, None, None], [2, 3, 4]), + ([2, 3, 4], [None, 3, None], [2, 3, 4]), + ([2, 3, 4], [None, 3, 4], [2, 3, 4]), + ([2, 3, 4], [2, 3, 4], [2, 3, 4]), + ) + def test_preserve_shape_roundtrip( + self, input_shape, to_tensor_shape, expected_shape): + tensor = array_ops.zeros(input_shape) + ragged_from_tensor = RaggedTensor.from_tensor(tensor, ragged_rank=2) + recovered_tensor = ragged_from_tensor.to_tensor(shape=to_tensor_shape) + self.assertAllEqual(tensor.shape.as_list(), expected_shape) + self.assertAllEqual(ragged_from_tensor.shape.as_list(), expected_shape) + self.assertAllEqual(recovered_tensor.shape.as_list(), expected_shape) + def test_empty_tensor_with_shape(self): input_data = RaggedTensor.from_value_rowids( values=constant_op.constant([], dtype=dtypes.int64), diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index dd53b388bd4..bee85dc4a5b 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -70,7 +70,7 @@ def _maybe_copy_to_context_device(tensor, device_name): class EagerFunc(object): """A wrapper for a function owned by an EagerPyFunc.""" - def __init__(self, func, Tout, is_grad_func, use_tape_cache=True): + def __init__(self, func, Tout, is_grad_func): """Constructs an EagerFunc. Args: @@ -79,12 +79,10 @@ class EagerFunc(object): None. is_grad_func: Whether this EagerFunc is the gradient of another EagerPyFunc. - use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`. """ self._func = func self._out_dtypes = Tout self._is_grad_func = is_grad_func - self._use_tape_cache = use_tape_cache def _convert(self, value, dtype): """Converts `value` to a tensor of type `dtype`, with error checking. @@ -148,8 +146,7 @@ class EagerFunc(object): else: outputs = _maybe_copy_to_context_device( self._convert(ret, dtype=self._out_dtypes[0]), device_name) - if self._use_tape_cache: - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) return outputs @@ -279,8 +276,7 @@ def _internal_py_func(func, stateful=None, eager=False, is_grad_func=False, - name=None, - use_tape_cache=True): + name=None): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError("Expected func to be callable, got func of type {}".format( @@ -296,7 +292,7 @@ def _internal_py_func(func, Tout = [Tout] if eager: - func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache) + func = EagerFunc(func, Tout, is_grad_func) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in @@ -373,35 +369,6 @@ def _EagerPyFuncGrad(op, *dy): is_grad_func=True) -# NOTE(lithuak): this function as a layer of indirection was added with one -# specific purpose: as a workaround for github issue #35084. -# It does all the same as `eager_py_func` used to do with one difference: -# it can be used to instruct underlying EagerFunc not to use `tape_cache` -# to avoid memory leak. When the issue #35084 is fixed - this function should -# be removed, its body should be moved back to become the body of -# `eager_py_func` and all the call sites should be reverted to -# using `eager_py_func` without `use_tape_cache` argument of any value. -def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True): - """Wraps a python function into a TensorFlow op that executes it eagerly.""" - if ops.executing_eagerly_outside_functions(): - with ops.device(context.context().host_address_space()): - return _internal_py_func( - func=func, - inp=inp, - Tout=Tout, - eager=True, - name=name, - use_tape_cache=use_tape_cache) - - return _internal_py_func( - func=func, - inp=inp, - Tout=Tout, - eager=True, - name=name, - use_tape_cache=use_tape_cache) - - @tf_export("py_function") def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. @@ -482,8 +449,12 @@ def eager_py_func(func, inp, Tout, name=None): A list of `Tensor` or a single `Tensor` which `func` computes; an empty list if `func` returns None. """ - return _eager_py_func( - func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True) + if ops.executing_eagerly_outside_functions(): + with ops.device(context.context().host_address_space()): + return _internal_py_func( + func=func, inp=inp, Tout=Tout, eager=True, name=name) + + return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) def py_func_common(func, inp, Tout, stateful=True, name=None): diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py index b0549041466..0beae1e55da 100644 --- a/tensorflow/python/ops/tensor_array_grad.py +++ b/tensorflow/python/ops/tensor_array_grad.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import tensor_array_ops # TODO(b/31222613): These ops may be differentiable, and there may be @@ -130,6 +131,12 @@ def _TensorArrayWriteGrad(op, flow): index = op.inputs[1] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) + flow_out = array_ops.identity(op.outputs[0], "flow_out") + # Avoid a race condition where the TensorArrayGrad op is executed before the + # final TensorArrayWrite by adding a control dependency on the output flow of + # the write to the input flow to the TensorArrayGrad. + with ops.control_dependencies([flow_out]): + flow = array_ops.identity(flow, "write_barrier") g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow, colocate_with_first_write_call=False) .grad(source=grad_source, flow=flow)) @@ -185,6 +192,12 @@ def _TensorArrayScatterGrad(op, flow): indices = op.inputs[1] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) + flow_out = array_ops.identity(op.outputs[0], "flow_out") + # Avoid a race condition where the TensorArrayGrad op is executed before the + # TensorArrayScatter by adding a control dependency on the output flow of + # the scatter to the input flow to the TensorArrayGrad. + with ops.control_dependencies([flow_out]): + flow = array_ops.identity(flow, "write_barrier") g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow, colocate_with_first_write_call=False) .grad(source=grad_source, flow=flow)) @@ -240,6 +253,12 @@ def _TensorArraySplitGrad(op, flow): handle = op.inputs[0] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) + flow_out = array_ops.identity(op.outputs[0], "flow_out") + # Avoid a race condition where the TensorArrayGrad op is executed before the + # TensorArraySplit by adding a control dependency on the output flow of + # the split to the input flow to the TensorArrayGrad. + with ops.control_dependencies([flow_out]): + flow = array_ops.identity(flow, "write_barrier") g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow, colocate_with_first_write_call=False) .grad(source=grad_source, flow=flow)) diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index b2f25339176..a13fa26e005 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -119,6 +119,7 @@ tf_python_pybind_extension( deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/convert:xplane_to_profile_response", + "//tensorflow/core/profiler/convert:xplane_to_trace_events", "//tensorflow/core/profiler/lib:profiler_session_headers", "//tensorflow/core/profiler/rpc:profiler_server", "//tensorflow/core/profiler/rpc/client:capture_profile", diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index c7780b7dc01..eaf7d09105b 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h" +#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h" #include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" @@ -56,8 +57,10 @@ class ProfilerSessionWrapper { py::bytes Stop() { tensorflow::string content; if (session_ != nullptr) { - tensorflow::Status status = session_->SerializeToString(&content); + tensorflow::profiler::XSpace xspace; + tensorflow::Status status = session_->CollectData(&xspace); session_.reset(); + tensorflow::profiler::ConvertXSpaceToTraceEventsString(xspace, &content); tensorflow::MaybeRaiseRegisteredFromStatus(status); } // The content is not valid UTF-8, so it must be converted to bytes. diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index 2cf01515181..9c71b853675 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -45,6 +45,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -420,7 +421,9 @@ class _TensorSpecCodec(object): """Codec for `TensorSpec`.""" def can_encode(self, pyobj): - return isinstance(pyobj, tensor_spec.TensorSpec) + # BoundedTensorSpec has its own decoder. + return (isinstance(pyobj, tensor_spec.TensorSpec) and + not isinstance(pyobj, tensor_spec.BoundedTensorSpec)) def do_encode(self, tensor_spec_value, encode_fn): encoded_tensor_spec = struct_pb2.StructuredValue() @@ -449,6 +452,45 @@ class _TensorSpecCodec(object): StructureCoder.register_codec(_TensorSpecCodec()) +class _BoundedTensorSpecCodec(object): + """Codec for `BoundedTensorSpec`.""" + + def can_encode(self, pyobj): + return isinstance(pyobj, tensor_spec.BoundedTensorSpec) + + def do_encode(self, bounded_tensor_spec_value, encode_fn): + """Returns an encoded proto for the given `tf.BoundedTensorSpec`.""" + encoded_bounded_tensor_spec = struct_pb2.StructuredValue() + encoded_bounded_tensor_spec.bounded_tensor_spec_value.CopyFrom( + struct_pb2.BoundedTensorSpecProto( + shape=encode_fn(bounded_tensor_spec_value.shape).tensor_shape_value, + dtype=encode_fn(bounded_tensor_spec_value.dtype).tensor_dtype_value, + name=bounded_tensor_spec_value.name, + minimum=tensor_util.make_tensor_proto( + bounded_tensor_spec_value.minimum), + maximum=tensor_util.make_tensor_proto( + bounded_tensor_spec_value.maximum))) + return encoded_bounded_tensor_spec + + def can_decode(self, value): + return value.HasField("bounded_tensor_spec_value") + + def do_decode(self, value, decode_fn): + btsv = value.bounded_tensor_spec_value + name = btsv.name + return tensor_spec.BoundedTensorSpec( + shape=decode_fn( + struct_pb2.StructuredValue(tensor_shape_value=btsv.shape)), + dtype=decode_fn( + struct_pb2.StructuredValue(tensor_dtype_value=btsv.dtype)), + minimum=tensor_util.MakeNdarray(btsv.minimum), + maximum=tensor_util.MakeNdarray(btsv.maximum), + name=(name if name else None)) + + +StructureCoder.register_codec(_BoundedTensorSpecCodec()) + + class _TypeSpecCodec(object): """Codec for `tf.TypeSpec`.""" diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index 23c305d0708..c68bc1017ee 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test from tensorflow.python.saved_model import nested_structure_coder @@ -35,6 +36,7 @@ from tensorflow.python.saved_model import nested_structure_coder class NestedStructureTest(test.TestCase): def setUp(self): + super(NestedStructureTest, self).setUp() self._coder = nested_structure_coder.StructureCoder() def testEncodeDecodeList(self): @@ -271,6 +273,54 @@ class NestedStructureTest(test.TestCase): ValueError, "The type 'FutureTensorSpec' is not supported"): self._coder.decode_proto(encoded) + def testEncodeDecodeBoundedTensorSpec(self): + structure = [ + tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10, + "hello-0-10") + ] + self.assertTrue(self._coder.can_encode(structure)) + encoded = self._coder.encode_structure(structure) + expected = struct_pb2.StructuredValue() + expected_list = expected.list_value + expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value + expected_tensor_spec.shape.dim.add().size = 1 + expected_tensor_spec.shape.dim.add().size = 2 + expected_tensor_spec.shape.dim.add().size = 3 + expected_tensor_spec.name = "hello-0-10" + expected_tensor_spec.dtype = dtypes.int64.as_datatype_enum + expected_tensor_spec.minimum.CopyFrom( + tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[])) + expected_tensor_spec.maximum.CopyFrom( + tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[])) + self.assertEqual(expected, encoded) + decoded = self._coder.decode_proto(encoded) + self.assertEqual(structure, decoded) + + def testEncodeDecodeBoundedTensorSpecNoName(self): + structure = [ + tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2, + (1, 1, 20)) + ] + self.assertTrue(self._coder.can_encode(structure)) + encoded = self._coder.encode_structure(structure) + expected = struct_pb2.StructuredValue() + expected_list = expected.list_value + expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value + expected_tensor_spec.shape.dim.add().size = 28 + expected_tensor_spec.shape.dim.add().size = 28 + expected_tensor_spec.shape.dim.add().size = 3 + expected_tensor_spec.name = "" + expected_tensor_spec.dtype = dtypes.float64.as_datatype_enum + expected_tensor_spec.minimum.CopyFrom( + tensor_util.make_tensor_proto([-2], dtype=dtypes.float64, shape=[])) + expected_tensor_spec.maximum.CopyFrom( + tensor_util.make_tensor_proto([1, 1, 20], + dtype=dtypes.float64, + shape=[3])) + self.assertEqual(expected, encoded) + decoded = self._coder.decode_proto(encoded) + self.assertEqual(structure, decoded) + def testEncodeDataSetSpec(self): structure = [dataset_ops.DatasetSpec( {"rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32), diff --git a/tensorflow/python/tf2.py b/tensorflow/python/tf2.py index fd1c8c1757a..bc713d6e28b 100644 --- a/tensorflow/python/tf2.py +++ b/tensorflow/python/tf2.py @@ -24,7 +24,6 @@ from __future__ import print_function import os - _force_enable = None @@ -44,5 +43,5 @@ def enabled(): """Returns True iff TensorFlow 2.0 behavior should be enabled.""" if _force_enable is None: return os.getenv("TF2_BEHAVIOR", "0") != "0" - else: - return _force_enable + + return _force_enable diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 09221d8b0a2..ee43e6e1d43 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -713,6 +713,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { &TFE_ContextOptionsSetDevicePlacementPolicy); m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy", &TFE_ContextOptionsSetLazyRemoteInputsCopy); + m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt); m.def("TFE_ContextOptionsSetMirroringPolicy", &TFE_ContextOptionsSetMirroringPolicy); m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync); @@ -1073,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) { return capsule; }); - m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) { + m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule, + const py::handle& context) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); if (absl::string_view(pycapsule.name()) != @@ -1084,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { absl::string_view(pycapsule.name())); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); } - TFE_TensorHandle* thandle = - tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + + TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( + pycapsule, status.get(), tensorflow::InputTFE_Context(context)); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py index 5e9400d6d25..806f1aa9096 100644 --- a/tensorflow/python/tools/import_pb_to_tensorboard.py +++ b/tensorflow/python/tools/import_pb_to_tensorboard.py @@ -41,19 +41,17 @@ except ImportError: def import_to_tensorboard(model_dir, log_dir, tag_set): - """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. + """View an SavedModel as a graph in Tensorboard. Args: - model_dir: The location of the protobuf (`pb`) model to visualize + model_dir: The directory containing the SavedModel to import. log_dir: The location for the Tensorboard log to begin visualization from. tag_set: Group of tag(s) of the MetaGraphDef to load, in string format, - separated by ','. For tag-set contains multiple tags, all tags must be - passed in. - - Usage: - Call this function with your model location and desired log directory. - Launch Tensorboard by pointing it to the log directory. - View your imported `.pb` model as a graph. + separated by ','. For tag-set contains multiple tags, all tags must be + passed in. + Usage: Call this function with your SavedModel location and desired log + directory. Launch Tensorboard by pointing it to the log directory. View your + imported SavedModel as a graph. """ with session.Session(graph=ops.Graph()) as sess: input_graph_def = saved_model_utils.get_meta_graph_def(model_dir, diff --git a/tensorflow/python/tools/module_util.py b/tensorflow/python/tools/module_util.py index 7a91eaeae92..66baafd552e 100644 --- a/tensorflow/python/tools/module_util.py +++ b/tensorflow/python/tools/module_util.py @@ -59,7 +59,7 @@ def get_parent_dir_for_name(module_name): spec = importlib.util.find_spec(name_split[0]) except ValueError: return None - if not spec.origin: + if not spec or not spec.origin: return None base_path = os.path.dirname(spec.origin) return os.path.join(base_path, *name_split[1:-1]) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 6e60e58b345..261ee1b9e9d 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -50,6 +50,7 @@ from tensorflow.python.saved_model import save from tensorflow.python.saved_model import signature_constants from tensorflow.python.tools import saved_model_aot_compile from tensorflow.python.tools import saved_model_utils +from tensorflow.python.tpu import tpu _XLA_DEBUG_OPTIONS_URL = ( @@ -438,7 +439,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, print('Initializing TPU System ...') # This is needed for freshly started worker, or if the job # restarts after a preemption. - sess.run(tf.contrib.tpu.initialize_system()) + sess.run(tpu.initialize_system()) loader.load(sess, tag_set.split(','), saved_model_dir) diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index 65288a92a8a..c6853e1fc63 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -1,6 +1,6 @@ """Definitions for using tools like saved_model_cli.""" -load("//tensorflow:tensorflow.bzl", "if_xla_available") +load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available") load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple") def _maybe_force_compile(args, force_compile): @@ -121,12 +121,15 @@ def saved_model_compile_aot( "{}_makefile.inc".format(name), ], cmd = ( - "$(location :saved_model_cli) aot_compile_cpu " + + "$(location {}) aot_compile_cpu ".format( + clean_dep("//tensorflow/python/tools:saved_model_cli"), + ) + "--dir \"$$(dirname $(location {}))\" ".format(saved_model) + checkpoint_cmd_args + "--output_prefix $(@D)/{} ".format(name) + "--cpp_class {} ".format(cpp_class) + "--variables_to_feed {} ".format(variables_to_feed) + + "--signature_def_key {} ".format(signature_def) + "--target_triple " + target_triple + " " + "--tag_set {} ".format(tag_set) ), diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index cf32d933e0c..ebf0a4ffc57 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -32,6 +32,7 @@ py_test( "no_oss_py2", "no_oss_py35", "no_pip", + "no_rocm", ], deps = [ "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/tpu/profiler/BUILD b/tensorflow/python/tpu/profiler/BUILD index eb77e5de742..b505262c6a2 100644 --- a/tensorflow/python/tpu/profiler/BUILD +++ b/tensorflow/python/tpu/profiler/BUILD @@ -11,8 +11,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":profiler_analysis_pb2_grpc", - "//tensorflow/core:protos_all_py", "//tensorflow/core/profiler:profiler_analysis_proto_py", + "//tensorflow/core/profiler/protobuf:trace_events_proto_py", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/tpu/profiler/__init__.py b/tensorflow/python/tpu/profiler/__init__.py index c8db3eafe6d..f021019b208 100644 --- a/tensorflow/python/tpu/profiler/__init__.py +++ b/tensorflow/python/tpu/profiler/__init__.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.core.protobuf.trace_events_pb2 import * +from tensorflow.core.profiler.protobuf.trace_events_pb2 import * from tensorflow.core.profiler.profiler_analysis_pb2 import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 87a2309cedf..d355bd6205a 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -101,6 +101,117 @@ _TT_EVENT_FILE_SUFFIX = '.tensor_tracer' _TT_SUMMARY_MAX_QUEUE = 100 +def set_parameters(tensor_tracer_params=None): + """Enables tensor tracer and sets its parameters. + + Example usage: + tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir', + 'trace_mode': 'norm', + 'report_file': '/usr/tmp/trace_dir/report.all'} + tensor_tracer.set_parameters(tensor_tracer_parameters) + + This sets up the parameters for tensor tracer. A call to tensor tracer as + below is necessary to enable debugging on CPUs and GPUs. On TPUs below can be + skipped as this call is hooked into tpu.rewrite. + tt = tensor_tracer.TensorTracer() + loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss) + + Args: + tensor_tracer_params: Tensor tracer parameter dictionary. Below gives + examples of these parameters: See tensor_tracer_report.py for all + parameters. + - enable: If set, tensor tracer will be enabled. Calling + enable_tensor_tracer automatically adds this parameters. + - trace_mode: The trace_mode to be used by tensor tracer. These include: + - summary: Collects multiple statistics for traced tensors, and writes + them a summary file that can be visualized using tensorboard. This + mode currently only works for TPUEstimator. It can be also be used + for other models, but outfeed must be handled by the user. + - norm: Collects norm of each traced tensor and writes them into a + text file pointed by 'trace_dir' flag. (Default mode). + - nan-inf: Checks the existince of NaNs and Infs in the tensor, and + writes a boolean value to a text file pointed by 'trace_dir' flag. + Note that 'norm' mode can also capture this information with more + numerical info. + - max-abs: Collects the absolute max for each traced tensors and + writes it into a text file pointed by 'trace_dir' flag. + - full-tensor: Writes the full tensor content of the traced tensors + into a text file pointed by 'trace_dir' flag. + - part-tensor: Writes a part of the tensor content of the traced + tensors into a text file pointed by 'trace_dir' flag. + - full_tensor_summary: Writes the full tensors as binary event files. + The outputs can be read using: trace = + tensor_tracer.read_tensor_tracer_event_file(event_file_path) + - trace-back-if-nan: This mode will write the full tensor content only + when the tensor has a NaN or Inf in it. It is possible to also print + the inputs coming to this op using 'trace_stack_size' parameter. + E.g., if trace_stack_size=2, then the tensor with NaN/Inf, its + inputs, and its inputs' inputs will also be printed. + - report_file: Path to the metadata file that is written during graph + construction. If not set, metadata will be printed to stdout during + graph construction. + - trace_dir: Path where the execution traces will be written during the + graph execution. If not set, trace will be printed to stderr. + - trace_level: Tensor tracer aims to trace everything it can. This + introduces some overhead on graph execution and graph compilation + times. Using trace_level parameter, it is possible to trace operation + based on their priorities. For example, - trace_level=7 is the highest + trace_level, in which every op is traced. - trace_level=6 will skip + constant operations such as tf.constant. - trace_level=5 will skip + less important ops such as tf.identities. - The default trace_level=3, + that will skip concat ops, or random number generators. - To reduce + the graph compile time overhead, trace_level can be set to 0, that + will skip additions, and substractions, and multiplications as well. + - excluded_opnames: If set, any matching op name will not be traced. + excluded_opnames can be set as a regular expression. E.g, + excluded_opnames=.* will exclude everything. + - excluded_optypes: If set, any matching op type will not be traced. + excluded_optypes can be set as a regular expression. E.g, + excluded_optypes=.* will exclude everything. excluded_optypes=MatMul + will exclude all MatMul ops from tracing. + - included_opnames: If set, any matching op name will be forced to be + traced. included_opnames can be set as a regular expression. E.g, + '--included_opnames=some_op --excluded_opname=*.' will only trace + some_op. + - included_optypes: If set, any matching op type will be forced to be + traced. included_optypes can be set as a regular expression. E.g, + '--included_optypes=some_op_type --excluded_optypes=*.' will trace + only the ops with type 'some_op_type' + Advanced Flags: + - compact_trace: If not set, statistics per tensor is written as soon as + they are executed. If set, then statistics for all traced tensors will + be stored in a cache and will be written only once per step. This flag + is ignored for full-tensor and part-tensor trace modes. If the + trace_dir is a remote directory, compact_trace will be forced. + - trace_scalar: Scalar values are not traced by default. If this flag is + set, scalar values will also be traced. + - included_cores: Accepts a list string. Tracing will only be dumped for + these cores. E.g, setting it to '[0,2,4,6]' will result in a trace + only for those cores. + - op_range: In the form of '%d:%d' that limits the tracing to the ops + within this limit. --op_range='5:10' will trace only the ops that have + topological order between 5-10. + - trace_before_included_ops: If set to a number-k, it will also trace + distance-k inputs of each traced tensor. E.g., k=1, then in addition + to each traced_tensor, their input tensors will also be traced. + - trace_after_included_ops: Same as trace_before_included_ops, where it + will also trace distance-k outputs of each traced tensor. + - submode: 'brief' or 'detailed'. If the trace mode is not compact, + brief mode will print only the id of each traced tensor to save some + space. 'detailed' mode prints the full tensor name. + - trace_stack_size: Used only for trace_mode=trace-back-if-nan mode. It + determines how many ops to print back from a nan op. E.g, op4 -> op3 + -> op2 -> op1 -> op0, if op0 has a NaN and trace_stack_size is 1, the + result of op1 will also be printed. trace_stack_size is 2, the result + of op1 and op2 will be printed. + """ + flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE + if tensor_tracer_params: + for key, value in tensor_tracer_params.items(): + flags += ' --%s=%s' % (key, value) + os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = flags + + def op_priority(op_type): """Returns the priority of the op. @@ -147,6 +258,15 @@ def op_priority(op_type): def read_tensor_tracer_event_file(event_file): """Reads the event file written by tensor tracer. + This can be used to read the full tensors written into binary event files by + by TensorTracer with trace_mode=full_tensor_summary. + + Example usage: + result_dict = tensor_tracer.read_tensor_tracer_event_file(event_file_path) + for step, tensor_dict in result_dict.items(): + for tensor_name, full_tensor_content in tensor_dict.items(): + logging.info(tensor_name, full_tensor_content) + Args: event_file: Path to the event file that contains only tensor tracer events. Returns: @@ -179,31 +299,40 @@ def read_tensor_tracer_event_file(event_file): return event_dict -def tensor_tracepoint(tensor, checkpoint_name): - """Adds a checkpoint with the given checkpoint name for the given tensor. +def trace_tensor(tensor, tracepoint_name=None): + """Programmatic interface to trace a tensor with Tensor Tracer. - The tensor will be added to the list of tensors that will be traced by the - tensor tracer. + Tensor Tracer, by default, traces all tensors in the execution. This function + can be used to limit traced tensors. If this function is called for a subset + of the tensors, only those will be traced. + For example, Tensor Traacer will only trace c below. + c = tf.MatMul(a, b) + tensor_tracer.trace_tensor(c) + d = tf.add(c, 1) Args: tensor: the tensor object for which the tracing is requested. - checkpoint_name: a string name for the checkpoint. This name has to be a - unique name if used within model comparison. The tensors that have the same - checkpoint identifier is compared in model comparison. + tracepoint_name: an optional tensor tracepoint name string. A tracepoint + name is an Tensor Tracer internal name for the tensor. It is useful when + comparing equivalent traces from different models that have different + tensor namings. Equivalent tensors (with different names) can be mapped + to each other by assigning a common tracepoint_name. + Returns: The provided tensor. """ - + if tracepoint_name is None: + tracepoint_name = tensor.name tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, - (tensor, checkpoint_name)) + (tensor, tracepoint_name)) return tensor def keras_layer_tracepoint(layer, checkpoint_name): """An interface for adding the tensor outputs of a keras layer. - Encapsulates tensor_tracepoint. + Encapsulates trace_tensor. Args: layer: A keras layer. @@ -217,12 +346,12 @@ def keras_layer_tracepoint(layer, checkpoint_name): try: outputs = layer.output if tensor_util.is_tensor(outputs): - tensor_tracepoint(outputs, '%s' % (checkpoint_name)) + trace_tensor(outputs, '%s' % (checkpoint_name)) else: idx = 0 for output_tensor in outputs: if tensor_util.is_tensor(outputs): - tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx)) idx += 1 except AttributeError: pass @@ -250,20 +379,40 @@ def _trace_files_need_precreated(output_dir): class TensorTracer(object): - """A software construct for tracing tensor values in a TF graph on TPU. + """A software construct for tracing tensor values in a TF graph. - This utility is disabled by default. It can be enabled by setting - the TENSOR_TRACER_FLAGS env variable as: + This utility is disabled by default. It is hooked into tpu.rewrite, so it can + easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as + below without a code change. export TENSOR_TRACER_FLAGS="--enable=1" + + Below is the use example to enable it on CPUs or GPUs, or for more advance use + cases on TPUs. + + a = x + 1 + b = a * 2 + rs = tf.reduce_sum(b) + tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir', + 'report_file: 'path/to/report/file'}) + tt = tensor_tracer.TensorTracer() + if on_tpu: + rs = tt.trace_tpu(tf.get_default_graph(), + tensor_fetches=rs) + else: + rs = tt.trace_cpu(tf.get_default_graph(), + tensor_fetches=rs) + session.run(rs) + If it is enabled, it will trace the output tensor values of selected Ops in the graph. It has two outputs: (1) the traces and (2) - a report. The traces are dumped to a specified local file on the TPU - host. The report is printed to the log.info of the TPU job. + a report. The traces are dumped to a specified directory during the graph + execution, while the report is dumped during the graph construction. By passing options via the env variable, users can change: (1) the trace mode (e.g., detecting NaN/Inf, printing partial or full tensor values) (2) which Ops to be traced (via op.name or op.type) (3) output trace file path. + """ # The set of graphs that are rewritten by tensor tracer. _traced_graphs = set() diff --git a/tensorflow/python/tpu/tensor_tracer_flags.py b/tensorflow/python/tpu/tensor_tracer_flags.py index 37f8ce408b1..badc44f263d 100644 --- a/tensorflow/python/tpu/tensor_tracer_flags.py +++ b/tensorflow/python/tpu/tensor_tracer_flags.py @@ -39,39 +39,43 @@ TRACE_MODE_SUMMARY = 'summary' TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary' # Full tensor mode dumps the whole tensor values for the traced tensors without # any processing on them; using tb summaries. -_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size' + _SUBMODE_BRIEF = 'brief' _SUBMODE_DETAILED = 'detailed' -_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' + _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') _FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') -_FLAG_NAME_ENABLE = 'enable' -_FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace' -_FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar' -_FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops' -_FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops' -_FLAG_NAME_SUBMODE = 'submode' -_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' -_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' -_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' -_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' -_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' -_FLAG_NAME_INCLUDED_CORES = 'included_cores' -_FLAG_NAME_TRACE_LEVEL = 'trace_level' -_FLAG_NAME_TRACE_DIR = 'trace_dir' -_FLAG_NAME_REPORT_FILE = 'report_file' -_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' -_FLAG_NAME_OP_RANGE = 'op_range' + +FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' +FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size' +FLAG_NAME_ENABLE = 'enable' +FLAG_NAME_TRACE_MODE = 'trace_mode' +FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace' +FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar' +FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops' +FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops' +FLAG_NAME_SUBMODE = 'submode' +FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' +FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' +FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' +FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' +FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' +FLAG_NAME_INCLUDED_CORES = 'included_cores' +FLAG_NAME_TRACE_LEVEL = 'trace_level' +FLAG_NAME_TRACE_DIR = 'trace_dir' +FLAG_NAME_REPORT_FILE = 'report_file' +FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' +FLAG_NAME_OP_RANGE = 'op_range' # Folder to dump the pre (before tensor tracer updates) and post graphs (after # tensor tracer updates). -_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' +FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' +FLAG_NAME_SUMMARY_SIGNATURES = 'signatures' +FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' + _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' -_FLAG_SUMMARY_SIGNATURES = 'signatures' -_FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' _TT_DEFAULT_TRACE_LEVEL = 3 _TT_PREFIX = 'tensor_tracer' @@ -93,8 +97,6 @@ TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE) TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN, TT_SUMMARY_MEAN, TT_SUMMARY_VAR, TT_SUMMARY_SIZE) -_TT_DEFAULT_TRACE_LEVEL = 3 - class TTParameters(object): """A class that handles the parameters of Tensor Tracer.""" @@ -111,18 +113,18 @@ class TTParameters(object): self.report_file_path = self._get_report_filepath() self.op_range = self._get_op_range() self.excluded_opname_re_list = self._flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPNAMES) + FLAG_NAME_EXCLUDED_OPNAMES) self.excluded_optype_re_list = self._flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPTYPES) + FLAG_NAME_EXCLUDED_OPTYPES) self.included_opname_re_list = self._flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPNAMES) + FLAG_NAME_INCLUDED_OPNAMES) self.included_optype_re_list = self._flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPTYPES) + FLAG_NAME_INCLUDED_OPTYPES) self.is_conditional_trace = self._is_conditional_trace_mode() - self.trace_scalar_ops = self.is_flag_on(_FLAG_NAME_TRACE_SCALAR_OPS) - self.use_compact_trace = self.is_flag_on(_FLAG_NAME_USE_COMPACT_TRACE) + self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS) + self.use_compact_trace = self.is_flag_on(FLAG_NAME_USE_COMPACT_TRACE) # _trace_ops_before_included and _trace_ops_after_included denotes to depth # of tracing relative to the ops given in --included_opnames or @@ -135,21 +137,20 @@ class TTParameters(object): # included op. Similarly, if --trace_after_included_ops=2, then op4 and op5 # will also be traced. self.trace_ops_before_included = self._get_flag_int_value( - _FLAG_NAME_TRACE_BEFORE_OPS, 0) + FLAG_NAME_TRACE_BEFORE_OPS, 0) self.trace_ops_after_included = self._get_flag_int_value( - _FLAG_NAME_TRACE_AFTER_OPS, 0) - self.trace_stack_size = self._get_flag_int_value( - _FLAG_NAME_TRACE_STACK_SIZE, 1) + FLAG_NAME_TRACE_AFTER_OPS, 0) + self.trace_stack_size = self._get_flag_int_value(FLAG_NAME_TRACE_STACK_SIZE, + 1) _, self.graph_dump_path = self.get_flag_value( - _FLAG_DUMP_BEFORE_AFTER_GRAPHS) - self.included_cores = self._flag_value_as_int_list( - _FLAG_NAME_INCLUDED_CORES) + FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS) + self.included_cores = self._flag_value_as_int_list(FLAG_NAME_INCLUDED_CORES) self.include_less_interesting_ops = self.is_flag_on( - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) - self.trace_level = self._get_flag_int_value( - _FLAG_NAME_TRACE_LEVEL, _TT_DEFAULT_TRACE_LEVEL) + FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) + self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL, + _TT_DEFAULT_TRACE_LEVEL) self.summary_signatures = self._get_summary_signatures() - self.collect_summary_per_core = self.is_flag_on(_FLAG_NAME_SUMMARY_PER_CORE) + self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE) def _is_conditional_trace_mode(self): return self.trace_mode == TRACE_MODE_FULL_IF_NAN @@ -157,8 +158,7 @@ class TTParameters(object): def _get_report_filepath(self): """Sets the path of the output report file.""" - found, report_file_path = self.get_flag_value( - _FLAG_NAME_REPORT_FILE) + found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE) if found and report_file_path \ and self.use_test_undeclared_outputs_dir(): if os.path.isabs(report_file_path): @@ -171,7 +171,7 @@ class TTParameters(object): def _get_op_range(self): """Sets the index range of the Ops that we will consider tracing.""" - found, op_range = self.get_flag_value(_FLAG_NAME_OP_RANGE) + found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE) if not found or not op_range: op_range = (-1, -1) # this means including all ops. return op_range @@ -183,12 +183,12 @@ class TTParameters(object): return op_range def _get_trace_dir(self): - found, trace_dir = self.get_flag_value(_FLAG_NAME_TRACE_DIR) + found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR) if found and trace_dir \ and self.use_test_undeclared_outputs_dir(): - raise ValueError('Cannot not use --%s and --%s at the same time' - %(_FLAG_NAME_TRACE_DIR, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) + raise ValueError( + 'Cannot not use --%s and --%s at the same time' % + (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) if self.use_test_undeclared_outputs_dir(): trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) return trace_dir @@ -196,7 +196,7 @@ class TTParameters(object): def _get_trace_mode(self): """Checks if the given trace mode is valid.""" - found, trace_mode = self.get_flag_value(_FLAG_NAME_TRACE_MODE) + found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE) if not found or not trace_mode: trace_mode = TRACE_MODE_NORM valid_trace_modes = [ @@ -216,7 +216,7 @@ class TTParameters(object): def _get_submode(self): """Checks if the given submode is valid.""" - found, submode = self.get_flag_value(_FLAG_NAME_SUBMODE) + found, submode = self.get_flag_value(FLAG_NAME_SUBMODE) if not found or not submode: submode = _SUBMODE_DETAILED if not submode: @@ -261,19 +261,19 @@ class TTParameters(object): def _validate_flag_names(self): """Validates if the TensorTrace flags passed are valid.""" valid_flag_names = [ - _FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE, - _FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS, - _FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE, - _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES, - _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES, - _FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR, - _FLAG_NAME_INCLUDED_CORES, _FLAG_NAME_REPORT_FILE, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE, - _FLAG_DUMP_BEFORE_AFTER_GRAPHS, _FLAG_NAME_TRACE_LEVEL, - _FLAG_SUMMARY_SIGNATURES, _FLAG_NAME_SUMMARY_PER_CORE + FLAG_NAME_ENABLE, FLAG_NAME_TRACE_MODE, FLAG_NAME_USE_COMPACT_TRACE, + FLAG_NAME_TRACE_SCALAR_OPS, FLAG_NAME_TRACE_BEFORE_OPS, + FLAG_NAME_TRACE_AFTER_OPS, FLAG_NAME_TRACE_STACK_SIZE, + FLAG_NAME_SUBMODE, FLAG_NAME_EXCLUDED_OPNAMES, + FLAG_NAME_EXCLUDED_OPTYPES, FLAG_NAME_INCLUDED_OPNAMES, + FLAG_NAME_INCLUDED_OPTYPES, FLAG_NAME_TRACE_DIR, + FLAG_NAME_INCLUDED_CORES, FLAG_NAME_REPORT_FILE, + FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, FLAG_NAME_OP_RANGE, + FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL, + FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE ] - tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR) + tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) if not tensor_tracer_flags: return pos = 0 @@ -286,7 +286,7 @@ class TTParameters(object): raise ValueError( 'The flag name "%s" passed via the environment variable "%s" ' 'is invalid. Valid flag names are:' - '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + '\n%s' % (flag_name, FLAGS_ENV_VAR, valid_flag_names)) pos = match.end() def _get_summary_signatures(self): @@ -296,7 +296,7 @@ class TTParameters(object): A dictionary of the signature identifiers {signature: index} that will be computed when trace_mode is summary. """ - signatures = self._flag_value_as_list(_FLAG_SUMMARY_SIGNATURES) + signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES) tt_signatures = [] for signature in signatures: @@ -398,7 +398,7 @@ class TTParameters(object): RuntimeError: If supposedly deadcode is reached. """ - tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR) + tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) if not tensor_tracer_flags: return False, None pos = 0 @@ -446,9 +446,9 @@ class TTParameters(object): def is_enabled(self): """Returns True if TensorTracer is enabled.""" - if self.is_flag_on(_FLAG_NAME_ENABLE): + if self.is_flag_on(FLAG_NAME_ENABLE): logging.info('Tensor Tracer is enabled with flags %s.' % - self._env.get(_FLAGS_ENV_VAR)) + self._env.get(FLAGS_ENV_VAR)) return True else: return False @@ -465,4 +465,4 @@ class TTParameters(object): env variable. """ - return self.is_flag_on(_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) + return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) diff --git a/tensorflow/python/tpu/tpu.bzl b/tensorflow/python/tpu/tpu.bzl index 767f27ded54..ba58e57e90d 100644 --- a/tensorflow/python/tpu/tpu.bzl +++ b/tensorflow/python/tpu/tpu.bzl @@ -43,6 +43,7 @@ def tpu_py_test( "no_pip", "no_gpu", "nomac", + "local", ] + tags test_main = kwargs.get("srcs") diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 768ef072052..7f65f623f3f 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -1218,6 +1218,7 @@ def split_compile_and_replicate(computation, for i in inputs[0]]), infeed_queue.number_of_tuple_elements, arg_error)) + dynamic_shape_inputs = False if maximum_shapes: if infeed_queue: raise ValueError( @@ -1248,6 +1249,8 @@ def split_compile_and_replicate(computation, flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes, padding_spec) + if padding_maps: + dynamic_shape_inputs = True serialized_padding_maps = [] for padding_map in padding_maps: @@ -1304,7 +1307,7 @@ def split_compile_and_replicate(computation, # inputs when dynamic padding is enabled. # TODO(rxsang): Use other ways except argument index in padding_map so # outside compilation can work with dynamic padding correctly. - if maximum_shapes is None or composite: + if not dynamic_shape_inputs or composite: i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access @@ -1339,9 +1342,8 @@ def split_compile_and_replicate(computation, kwargs["partitioner"] = None logging.warning( "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {} for variable {}. " - "Setting `partitioner` to `None`." - .format(partitioner, name)) + "`partitioner` that is %s for variable %s. " + "Setting `partitioner` to `None`.", partitioner, name) if saved_custom_getter is None: return getter(name, *args, **kwargs) else: diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py index 8e1bf42ddb8..38377dd0600 100644 --- a/tensorflow/python/training/experimental/mixed_precision.py +++ b/tensorflow/python/training/experimental/mixed_precision.py @@ -23,49 +23,35 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import optimizer from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1 from tensorflow.python.training.experimental import mixed_precision_global_state -from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +# A mapping between optimizers and the corresponding wrapper class that will be +# used for mixed precision. +_REGISTERED_WRAPPER_OPTIMIZER_CLS = { + optimizer.Optimizer: + loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer, +} + + +def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls): + _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls + + def _wrap_optimizer(opt, loss_scale, use_v1_behavior): """Wraps an optimizer with a LossScaleOptimizer.""" - if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer): - raise ValueError('"opt" must not already be an instance of a ' - 'MixedPrecisionLossScaleOptimizer. ' - '`enable_mixed_precision_graph_rewrite` will ' - 'automatically wrap the optimizer with a ' - 'MixedPrecisionLossScaleOptimizer.') - # To avoid a circular dependency, we cannot depend on tf.keras. Because - # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check - # the class name. - if opt.__class__.__name__ == 'LossScaleOptimizer': - raise ValueError('"opt" must not already be an instance of a ' - 'LossScaleOptimizer. ' - '`enable_mixed_precision_graph_rewrite` will ' - 'automatically wrap the optimizer with a ' - 'LossScaleOptimizer.') + for wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values(): + if isinstance(opt, wrapper_optimizer): + raise ValueError('"opt" must not already be an instance of a {cls}. ' + '`enable_mixed_precision_graph_rewrite` will ' + 'automatically wrap the optimizer with a ' + '{cls}.' + .format(cls=wrapper_optimizer.__name__)) - if isinstance(opt, optimizer.Optimizer): - # For convenience, we allow the V2 version of this function to wrap the V1 - # optimizer, even though we do not document this. - return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt, - loss_scale) - - # Because we cannot depend on tf.keras, we see if `opt` is an instance of the - # Keras OptimizerV2 class by checking the subclass names. - base_classes = tf_inspect.getmro(opt.__class__) - base_class_names = [cls.__name__ for cls in base_classes] - is_loss_scale_optimizer_v2 = 'OptimizerV2' in base_class_names - - if is_loss_scale_optimizer_v2: - # Because we cannot depend on tf.keras, we cannot unconditionally do this - # import. But since `opt` is a Keras OptimizerV2, we know keras is - # importable, so it is safe to do this import. (Technically, it's possible - # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this - # is not done in practice). - from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2 # pylint: disable=g-import-not-at-top - return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale) + for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items(): + if isinstance(opt, optimizer_cls): + return wrapper_cls(opt, loss_scale) if use_v1_behavior: raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a ' diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index f86e68d188f..86a718f8c5b 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -17,755 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule -from tensorflow.python.ops import math_ops -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.keras.optimizer_v2 import legacy_learning_rate_decay as learning_rate_decay -@tf_export(v1=["train.exponential_decay"]) -def exponential_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies exponential decay to the learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an exponential decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - - ```python - decayed_learning_rate = learning_rate * - decay_rate ^ (global_step / decay_steps) - ``` - - If the argument `staircase` is `True`, then `global_step / decay_steps` is an - integer division and the decayed learning rate follows a staircase function. - - Example: decay every 100000 steps with a base of 0.96: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - starter_learning_rate = 0.1 - learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate, - global_step, - 100000, 0.96, staircase=True) - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.compat.v1.train.GradientDescentOptimizer(learning_rate) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. Must not be negative. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must - be positive. See the decay computation above. - decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number. - The decay rate. - staircase: Boolean. If `True` decay the learning rate at discrete intervals - name: String. Optional name of the operation. Defaults to - 'ExponentialDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - - Raises: - ValueError: if `global_step` is not supplied. - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.ExponentialDecay( - learning_rate, decay_steps, decay_rate, staircase=staircase, name=name) - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.piecewise_constant_decay", "train.piecewise_constant"]) -def piecewise_constant(x, boundaries, values, name=None): - """Piecewise constant from boundaries and interval values. - - Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 - for the next 10000 steps, and 0.1 for any additional steps. - - ```python - global_step = tf.Variable(0, trainable=False) - boundaries = [100000, 110000] - values = [1.0, 0.5, 0.1] - learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries, - values) - - # Later, whenever we perform an optimization step, we increment global_step. - ``` - - Args: - x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`, - `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`. - boundaries: A list of `Tensor`s or `int`s or `float`s with strictly - increasing entries, and with all elements having the same type as `x`. - values: A list of `Tensor`s or `float`s or `int`s that specifies the values - for the intervals defined by `boundaries`. It should have one more element - than `boundaries`, and all elements should have the same type. - name: A string. Optional name of the operation. Defaults to - 'PiecewiseConstant'. - - Returns: - A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`, - `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ..., - and values[-1] when `x > boundaries[-1]`. - - Raises: - ValueError: if types of `x` and `boundaries` do not match, or types of all - `values` do not match or - the number of elements in the lists does not match. - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - boundaries = ops.convert_n_to_tensor(boundaries) - values = ops.convert_n_to_tensor(values) - x_recomp = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x_recomp.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x_recomp.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x_recomp.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( - boundaries, values, name=name) - if not context.executing_eagerly(): - decayed_lr = decayed_lr(x) - else: - decayed_lr = functools.partial(decayed_lr, x) - return decayed_lr - - -@tf_export(v1=["train.polynomial_decay"]) -def polynomial_decay(learning_rate, - global_step, - decay_steps, - end_learning_rate=0.0001, - power=1.0, - cycle=False, - name=None): - """Applies a polynomial decay to the learning rate. - - It is commonly observed that a monotonically decreasing learning rate, whose - degree of change is carefully chosen, results in a better performing model. - This function applies a polynomial decay function to a provided initial - `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`. - - It requires a `global_step` value to compute the decayed learning rate. You - can just pass a TensorFlow variable that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - - ```python - global_step = min(global_step, decay_steps) - decayed_learning_rate = (learning_rate - end_learning_rate) * - (1 - global_step / decay_steps) ^ (power) + - end_learning_rate - - ``` - - If `cycle` is True then a multiple of `decay_steps` is used, the first one - that is bigger than `global_steps`. - - ```python - decay_steps = decay_steps * ceil(global_step / decay_steps) - decayed_learning_rate = (learning_rate - end_learning_rate) * - (1 - global_step / decay_steps) ^ (power) + - end_learning_rate - - ``` - - Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5): - - ```python - ... - global_step = tf.Variable(0, trainable=False) - starter_learning_rate = 0.1 - end_learning_rate = 0.01 - decay_steps = 10000 - learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate, - global_step, - decay_steps, end_learning_rate, - power=0.5) - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.compat.v1.train.GradientDescentOptimizer(learning_rate) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. Must not be negative. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must - be positive. See the decay computation above. - end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python - number. The minimal end learning rate. - power: A scalar `float32` or `float64` `Tensor` or a Python number. The - power of the polynomial. Defaults to linear, 1.0. - cycle: A boolean, whether or not it should cycle beyond decay_steps. - name: String. Optional name of the operation. Defaults to - 'PolynomialDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - - Raises: - ValueError: if `global_step` is not supplied. - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.PolynomialDecay( - learning_rate, - decay_steps, - end_learning_rate=end_learning_rate, - power=power, - cycle=cycle, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.natural_exp_decay"]) -def natural_exp_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies natural exponential decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an exponential decay function - to a provided initial learning rate. It requires an `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - - ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / - decay_step) - ``` - - or, if `staircase` is `True`, as: - - ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / - decay_step)) - ``` - - Example: decay exponentially with a base of 0.96: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - learning_rate = 0.1 - decay_steps = 5 - k = 0.5 - learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate, - global_step, - decay_steps, k) - - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.compat.v1.train.GradientDescentOptimizer(learning_rate) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. - The initial learning rate. - global_step: A Python number. Global step to use for the decay computation. - Must not be negative. - decay_steps: How often to apply decay. - decay_rate: A Python number. The decay rate. - staircase: Whether to apply decay in a discrete staircase, as opposed to - continuous, fashion. - name: String. Optional name of the operation. Defaults to - 'ExponentialTimeDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - - Raises: - ValueError: if `global_step` is not supplied. - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate)) - decayed_lr = learning_rate_schedule.ExponentialDecay( - learning_rate, - decay_steps, - natural_exp_rate, - staircase=staircase, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.inverse_time_decay"]) -def inverse_time_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies inverse time decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an inverse decay function - to a provided initial learning rate. It requires an `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - - ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / - decay_step) - ``` - - or, if `staircase` is `True`, as: - - ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / - decay_step)) - ``` - - Example: decay 1/t with a rate of 0.5: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - learning_rate = 0.1 - decay_steps = 1.0 - decay_rate = 0.5 - learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate, - global_step, - decay_steps, decay_rate) - - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.compat.v1.train.GradientDescentOptimizer(learning_rate) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. - The initial learning rate. - global_step: A Python number. Global step to use for the decay computation. - Must not be negative. - decay_steps: How often to apply decay. - decay_rate: A Python number. The decay rate. - staircase: Whether to apply decay in a discrete staircase, as opposed to - continuous, fashion. - name: String. Optional name of the operation. Defaults to - 'InverseTimeDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - - Raises: - ValueError: if `global_step` is not supplied. - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.InverseTimeDecay( - learning_rate, decay_steps, decay_rate, staircase=staircase, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.cosine_decay"]) -def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): - """Applies cosine decay to the learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a cosine decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - ```python - global_step = min(global_step, decay_steps) - cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps)) - decayed = (1 - alpha) * cosine_decay + alpha - decayed_learning_rate = learning_rate * decayed - ``` - - Example usage: - ```python - decay_steps = 1000 - lr_decayed = cosine_decay(learning_rate, global_step, decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number - of steps to decay over. - alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum - learning rate value as a fraction of learning_rate. - name: String. Optional name of the operation. Defaults to 'CosineDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - Raises: - ValueError: if `global_step` is not supplied. - - References: - Stochastic Gradient Descent with Warm Restarts: - [Loshchilov et al., 2017] - (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) - ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.CosineDecay( - learning_rate, decay_steps, alpha=alpha, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.cosine_decay_restarts"]) -def cosine_decay_restarts(learning_rate, - global_step, - first_decay_steps, - t_mul=2.0, - m_mul=1.0, - alpha=0.0, - name=None): - """Applies cosine decay with restarts to the learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a cosine decay function with - restarts to a provided initial learning rate. It requires a `global_step` - value to compute the decayed learning rate. You can just pass a TensorFlow - variable that you increment at each training step. - - The function returns the decayed learning rate while taking into account - possible warm restarts. The learning rate multiplier first decays - from 1 to `alpha` for `first_decay_steps` steps. Then, a warm - restart is performed. Each new warm restart runs for `t_mul` times more steps - and with `m_mul` times smaller initial learning rate. - - Example usage: - ```python - first_decay_steps = 1000 - lr_decayed = cosine_decay_restarts(learning_rate, global_step, - first_decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. - first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Number of steps to decay over. - t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to - derive the number of iterations in the i-th period - m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. - Used to derive the initial learning rate of the i-th period: - alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum - learning rate value as a fraction of the learning_rate. - name: String. Optional name of the operation. Defaults to 'SGDRDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - Raises: - ValueError: if `global_step` is not supplied. - - References: - Stochastic Gradient Descent with Warm Restarts: - [Loshchilov et al., 2017] - (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) - ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - learning_rate, - first_decay_steps, - t_mul=t_mul, - m_mul=m_mul, - alpha=alpha, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.linear_cosine_decay"]) -def linear_cosine_decay(learning_rate, - global_step, - decay_steps, - num_periods=0.5, - alpha=0.0, - beta=0.001, - name=None): - """Applies linear cosine decay to the learning rate. - - Note that linear cosine decay is more aggressive than cosine decay and - larger initial learning rates can typically be used. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a linear cosine decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns the decayed learning rate. It is computed as: - ```python - global_step = min(global_step, decay_steps) - linear_decay = (decay_steps - global_step) / decay_steps) - cosine_decay = 0.5 * ( - 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) - decayed = (alpha + linear_decay) * cosine_decay + beta - decayed_learning_rate = learning_rate * decayed - ``` - - Example usage: - ```python - decay_steps = 1000 - lr_decayed = linear_cosine_decay(learning_rate, global_step, decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number - of steps to decay over. - num_periods: Number of periods in the cosine part of the decay. See - computation above. - alpha: See computation above. - beta: See computation above. - name: String. Optional name of the operation. Defaults to - 'LinearCosineDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - Raises: - ValueError: if `global_step` is not supplied. - - References: - Neural Optimizer Search with Reinforcement Learning: - [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html) - ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf)) - Stochastic Gradient Descent with Warm Restarts: - [Loshchilov et al., 2017] - (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) - ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.LinearCosineDecay( - learning_rate, - decay_steps, - num_periods=num_periods, - alpha=alpha, - beta=beta, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr - - -@tf_export(v1=["train.noisy_linear_cosine_decay"]) -def noisy_linear_cosine_decay(learning_rate, - global_step, - decay_steps, - initial_variance=1.0, - variance_decay=0.55, - num_periods=0.5, - alpha=0.0, - beta=0.001, - name=None): - """Applies noisy linear cosine decay to the learning rate. - - Note that linear cosine decay is more aggressive than cosine decay and - larger initial learning rates can typically be used. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a noisy linear - cosine decay function to a provided initial learning rate. - It requires a `global_step` value to compute the decayed learning rate. - You can just pass a TensorFlow variable that you increment at each - training step. - - The function returns the decayed learning rate. It is computed as: - ```python - global_step = min(global_step, decay_steps) - linear_decay = (decay_steps - global_step) / decay_steps) - cosine_decay = 0.5 * ( - 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) - decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta - decayed_learning_rate = learning_rate * decayed - ``` - where eps_t is 0-centered gaussian noise with variance - initial_variance / (1 + global_step) ** variance_decay - - Example usage: - ```python - decay_steps = 1000 - lr_decayed = noisy_linear_cosine_decay( - learning_rate, global_step, decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global - step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number - of steps to decay over. - initial_variance: initial variance for the noise. See computation above. - variance_decay: decay for the noise's variance. See computation above. - num_periods: Number of periods in the cosine part of the decay. See - computation above. - alpha: See computation above. - beta: See computation above. - name: String. Optional name of the operation. Defaults to - 'NoisyLinearCosineDecay'. - - Returns: - A scalar `Tensor` of the same type as `learning_rate`. The decayed - learning rate. - Raises: - ValueError: if `global_step` is not supplied. - - References: - Neural Optimizer Search with Reinforcement Learning: - [Bello et al., 2017](http://proceedings.mlr.press/v70/bello17a.html) - ([pdf](http://proceedings.mlr.press/v70/bello17a/bello17a.pdf)) - Stochastic Gradient Descent with Warm Restarts: - [Loshchilov et al., 2017] - (https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) - ([pdf](https://openreview.net/pdf?id=Skq89Scxx)) - - @compatibility(eager) - When eager execution is enabled, this function returns a function which in - turn returns the decayed learning rate Tensor. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - @end_compatibility - """ - decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay( - learning_rate, - decay_steps, - initial_variance=initial_variance, - variance_decay=variance_decay, - num_periods=num_periods, - alpha=alpha, - beta=beta, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr(global_step) - else: - decayed_lr = functools.partial(decayed_lr, global_step) - return decayed_lr +exponential_decay = learning_rate_decay.exponential_decay +piecewise_constant = learning_rate_decay.piecewise_constant +polynomial_decay = learning_rate_decay.polynomial_decay +natural_exp_decay = learning_rate_decay.natural_exp_decay +inverse_time_decay = learning_rate_decay.inverse_time_decay +cosine_decay = learning_rate_decay.cosine_decay +cosine_decay_restarts = learning_rate_decay.cosine_decay_restarts +linear_cosine_decay = learning_rate_decay.linear_cosine_decay +noisy_linear_cosine_decay = learning_rate_decay.noisy_linear_cosine_decay diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py index 54b22fa07f9..041ff38eedd 100644 --- a/tensorflow/python/training/tracking/graph_view.py +++ b/tensorflow/python/training/tracking/graph_view.py @@ -93,7 +93,9 @@ def _serialize_slot_variables(trackable_objects, node_ids, object_names): for trackable in non_slot_objects: if (isinstance(trackable, optimizer_v1.Optimizer) # TODO(b/110718070): Fix Keras imports. - or hasattr(trackable, "_create_or_restore_slot_variable")): + # Note: dir() is used rather than hasattr() here to avoid triggering + # custom __getattr__ code, see b/152031870 for context. + or "_create_or_restore_slot_variable" in dir(trackable)): naming_scheme = _slot_variable_naming_for_optimizer( optimizer_path=object_names[trackable]) slot_names = trackable.get_slot_names() diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index eeaa2a541c5..144ec068551 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -52,23 +52,28 @@ from tensorflow.python.training.tracking import graph_view as graph_view_lib from tensorflow.python.training.tracking import tracking from tensorflow.python.util import compat from tensorflow.python.util import deprecation -from tensorflow.python.util import lazy_loader from tensorflow.python.util import object_identity from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -# Loaded lazily due to a circular dependency. -keras_backend = lazy_loader.LazyLoader( - "keras_backend", globals(), - "tensorflow.python.keras.backend") +# The callable that provide Keras default session that is needed for saving. +_SESSION_PROVIDER = None + + +def register_session_provider(session_provider): + global _SESSION_PROVIDER + if _SESSION_PROVIDER is None: + _SESSION_PROVIDER = session_provider def get_session(): # Prefer TF's default session since get_session from Keras has side-effects. session = ops.get_default_session() if session is None: - session = keras_backend.get_session() + global _SESSION_PROVIDER + if _SESSION_PROVIDER is not None: + session = _SESSION_PROVIDER() # pylint: disable=not-callable return session diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD new file mode 100644 index 00000000000..040555b910f --- /dev/null +++ b/tensorflow/python/types/BUILD @@ -0,0 +1,31 @@ +load("//tensorflow:tensorflow.bzl", "py_strict_library") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +# Important: this is a leaf library. It may not have any new dependencies inside TF proper. +# The sole exception is tf_export, to allow exporting symbols into the public namespace. +py_strict_library( + name = "types", + srcs = [ + "__init__.py", + "core.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ], +) diff --git a/tensorflow/python/types/__init__.py b/tensorflow/python/types/__init__.py new file mode 100644 index 00000000000..72a749cff46 --- /dev/null +++ b/tensorflow/python/types/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Public TensorFlow type definitions. + +For details, see +https://github.com/tensorflow/community/blob/master/rfcs/20200211-tf-types.md. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Note: this module should contain **type definitions only**. diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py new file mode 100644 index 00000000000..20da83e562d --- /dev/null +++ b/tensorflow/python/types/core.py @@ -0,0 +1,60 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Core TensorFlow types.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced. +# TODO(mdan): Add type annotations. + + +class Tensor(object): + """The base class of all dense Tensor objects. + + A dense tensor has a static data type (dtype), and may have a static rank and + shape. Tensor objects are immutable. Mutable objects may be backed by a Tensor + which holds the unique handle that identifies the mutable object. + """ + + @property + def dtype(self): + pass + + @property + def shape(self): + pass + + +class Symbol(Tensor): + """Symbolic "graph" Tensor. + + These objects represent the output of an op definition and do not carry a + value. + """ + pass + + +class Value(Tensor): + """Tensor that can be associated with a value (aka "eager tensor"). + + These objects represent the (usually future) output of executing an op + immediately. + """ + + def numpy(self): + pass diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py index c34383c5f2e..1e5de4cb280 100644 --- a/tensorflow/python/util/serialization.py +++ b/tensorflow/python/util/serialization.py @@ -23,7 +23,6 @@ import wrapt from tensorflow.python.framework import tensor_shape from tensorflow.python.util.compat import collections_abc -from tensorflow.python.keras.utils import generic_utils def get_json_type(obj): @@ -41,10 +40,7 @@ def get_json_type(obj): # if obj is a serializable Keras class instance # e.g. optimizer, layer if hasattr(obj, 'get_config'): - return { - 'class_name': generic_utils.get_registered_name(obj.__class__), - 'config': obj.get_config() - } + return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} # if obj is any numpy type if type(obj).__module__ == np.__name__: diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py index 2f4b5d55b03..6df7533831b 100644 --- a/tensorflow/python/util/serialization_test.py +++ b/tensorflow/python/util/serialization_test.py @@ -23,12 +23,10 @@ import json from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util -from tensorflow.python.keras import losses from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core -from tensorflow.python.keras.utils import losses_utils, generic_utils from tensorflow.python.platform import test from tensorflow.python.util import serialization @@ -71,41 +69,5 @@ class SerializationTests(test.TestCase): self.assertEqual( 10, model_round_trip["config"]["layers"][1]["config"]["units"]) - @test_util.run_in_graph_and_eager_modes - def test_serialize_custom_model_compile(self): - with generic_utils.custom_object_scope(): - - @generic_utils.register_keras_serializable(package="dummy-package") - class DummySparseCategoricalCrossentropyLoss(losses.LossFunctionWrapper): - # This loss is identical equal to tf.keras.losses.SparseCategoricalCrossentropy - def __init__( - self, - from_logits=False, - reduction=losses_utils.ReductionV2.AUTO, - name="dummy_sparse_categorical_crossentropy_loss", - ): - super(DummySparseCategoricalCrossentropyLoss, self).__init__( - losses.sparse_categorical_crossentropy, - name=name, - reduction=reduction, - from_logits=from_logits, - ) - - x = input_layer.Input(shape=[3]) - y = core.Dense(10)(x) - model = training.Model(x, y) - model.compile( - loss=DummySparseCategoricalCrossentropyLoss(from_logits=True)) - model_round_trip = json.loads( - json.dumps(model.loss, default=serialization.get_json_type)) - - # check if class name with package scope - self.assertEqual("dummy-package>DummySparseCategoricalCrossentropyLoss", - model_round_trip["class_name"]) - - # check if configure is correctly - self.assertEqual(True, model_round_trip["config"]["from_logits"]) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/stream_executor/cuda/cublas_stub.cc b/tensorflow/stream_executor/cuda/cublas_stub.cc index 5c1b666bcef..dd13ad0960b 100644 --- a/tensorflow/stream_executor/cuda/cublas_stub.cc +++ b/tensorflow/stream_executor/cuda/cublas_stub.cc @@ -57,7 +57,7 @@ cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; } typedef enum {} cublasMath_t; #endif -#if CUDA_VERSION < 9020 +#if CUDA_VERSION < 10000 #include "tensorflow/stream_executor/cuda/cublas_9_0.inc" #elif CUDA_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cublas_10_0.inc" diff --git a/tensorflow/stream_executor/cuda/cuda_stub.cc b/tensorflow/stream_executor/cuda/cuda_stub.cc index ebdc4a33db6..ce02be89c22 100644 --- a/tensorflow/stream_executor/cuda/cuda_stub.cc +++ b/tensorflow/stream_executor/cuda/cuda_stub.cc @@ -93,7 +93,7 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS; typedef void(CUDA_CB* CUhostFn)(void* userData); -#if CUDA_VERSION <= 9000 +#if CUDA_VERSION < 10000 #include "tensorflow/stream_executor/cuda/cuda_9_0.inc" #elif CUDA_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cuda_10_0.inc" diff --git a/tensorflow/stream_executor/cuda/cudart_stub.cc b/tensorflow/stream_executor/cuda/cudart_stub.cc index 3afe6780402..5ee106a65fd 100644 --- a/tensorflow/stream_executor/cuda/cudart_stub.cc +++ b/tensorflow/stream_executor/cuda/cudart_stub.cc @@ -51,7 +51,7 @@ cudaError_t GetSymbolNotFoundError() { #define __CUDA_DEPRECATED // A bunch of new symbols were introduced in version 10 -#if CUDART_VERSION <= 9020 +#if CUDART_VERSION < 10000 #include "tensorflow/stream_executor/cuda/cuda_runtime_9_0.inc" #elif CUDART_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cuda_runtime_10_0.inc" diff --git a/tensorflow/stream_executor/cuda/cusparse_stub.cc b/tensorflow/stream_executor/cuda/cusparse_stub.cc index b2f76fe6d5c..783b034d7b6 100644 --- a/tensorflow/stream_executor/cuda/cusparse_stub.cc +++ b/tensorflow/stream_executor/cuda/cusparse_stub.cc @@ -51,7 +51,7 @@ cusparseStatus_t GetSymbolNotFoundError() { } } // namespace -#if CUDA_VERSION < 9020 +#if CUDA_VERSION < 10000 #include "tensorflow/stream_executor/cuda/cusparse_9_0.inc" #elif CUDA_VERSION == 10000 #include "tensorflow/stream_executor/cuda/cusparse_10_0.inc" diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index 50873f4491f..d40a7a88015 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -157,7 +157,7 @@ class HostExecutor : public internal::StreamExecutorInterface { } port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override { - string error_msg{ + std::string error_msg{ "Shared memory configuration is unsupported for host " "executors."}; LOG(INFO) << error_msg; diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc index ab3167d8686..a829ffc96fb 100644 --- a/tensorflow/stream_executor/host/host_platform.cc +++ b/tensorflow/stream_executor/host/host_platform.cc @@ -38,7 +38,7 @@ int HostPlatform::VisibleDeviceCount() const { return std::thread::hardware_concurrency(); } -const string& HostPlatform::Name() const { return name_; } +const std::string& HostPlatform::Name() const { return name_; } port::StatusOr> HostPlatform::DescriptionForDevice(int ordinal) const { diff --git a/tensorflow/stream_executor/host/host_platform.h b/tensorflow/stream_executor/host/host_platform.h index bf1d6c79589..a7eaa2b60cc 100644 --- a/tensorflow/stream_executor/host/host_platform.h +++ b/tensorflow/stream_executor/host/host_platform.h @@ -49,7 +49,7 @@ class HostPlatform : public Platform { // base::NumCPUs(). int VisibleDeviceCount() const override; - const string& Name() const override; + const std::string& Name() const override; port::StatusOr> DescriptionForDevice( int ordinal) const override; @@ -71,7 +71,7 @@ class HostPlatform : public Platform { private: // This platform's name. - string name_; + std::string name_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt index 780f5e5cb6a..e9e805426a1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.IndexedSlices" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt index 029d04fee9b..2ec5bb46ed1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,10 +4,6 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - member { - name: "dtype" - mtype: "" - } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt index 920a48a2294..44a66874e70 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.RaggedTensor" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "dtype" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt index 261a6ae35a4..d71812ce83a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.SparseTensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt index fd80af51d2a..33742e3b867 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.Tensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" member { name: "OVERLOADABLE_OPERATORS" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 841b142c082..872d03770ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -63,7 +63,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 42225d3f566..a84c5aa3caf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 81a1c7fbd9c..a3862ae2a19 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index e9e3962a498..baaaf7ea7be 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index 20712fb14a7..afdeea5d018 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index c139c6b9cc8..76113c5e01d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 41a67db62dc..1a11026fd19 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_sparse_tensor_slices" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt index 8f10d1698e7..74597c8c7cf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.Initializer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt index 0dc36c98210..c395b6c3fda 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.RandomNormal" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt index fd6647cb381..c43acee1881 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.RandomUniform" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt index 6fe47498271..0d6464c0b56 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.TruncatedNormal" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_normal.pbtxt new file mode 100644 index 00000000000..3d2141b445c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_normal.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.keras.initializers.he_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_uniform.pbtxt new file mode 100644 index 00000000000..5fa07a71315 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.he_uniform.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.keras.initializers.he_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_normal.pbtxt new file mode 100644 index 00000000000..e2a0ee1e548 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_normal.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.keras.initializers.lecun_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_uniform.pbtxt new file mode 100644 index 00000000000..070d2ff8be9 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.lecun_uniform.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.keras.initializers.lecun_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt index 145200c0a59..f1f9a7f0fba 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.normal" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt index 1540c2915bf..11794d5005a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt @@ -52,10 +52,26 @@ tf_module { name: "glorot_uniform" mtype: "" } + member { + name: "he_normal" + mtype: "" + } + member { + name: "he_uniform" + mtype: "" + } member { name: "identity" mtype: "" } + member { + name: "lecun_normal" + mtype: "" + } + member { + name: "lecun_uniform" + mtype: "" + } member { name: "normal" mtype: "" @@ -96,22 +112,6 @@ tf_module { name: "get" argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "he_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "he_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "serialize" argspec: "args=[\'initializer\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt index 51ff41cb283..fc24a55cb9b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.random_normal" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt index bdb6724d063..1d3fd598799 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.random_uniform" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt index 7647e170786..63232343dcd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.truncated_normal" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt index 0b85fbea918..ac43b6e992c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.uniform" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 68e37cc8475..a2998f59114 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 115afd297ba..3f750a6200b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-k-l-divergence.pbtxt index 30d68f097be..b15ba6f2d6d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-k-l-divergence.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kullback_leibler_divergence\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kl_divergence\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-log-cosh.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-log-cosh.pbtxt index 9310f07f509..1bdc6751a4a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-log-cosh.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-log-cosh.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'logcosh\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'log_cosh\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt index e24947ad19a..8f4c6a78b26 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt @@ -120,6 +120,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -128,6 +132,10 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "logcosh" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt index 2b09ccc48a2..57876312213 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt @@ -208,6 +208,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 54c0c8f9b5b..1247b3615f4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1554,11 +1554,11 @@ tf_module { } member_method { name: "lin_space" - argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'start\', \'stop\', \'num\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " } member_method { name: "linspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'start\', \'stop\', \'num\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " } member_method { name: "load_file_system_library" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt index a563015afe3..a3ea216468e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.sparse.SparseTensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt index 780f5e5cb6a..e9e805426a1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.IndexedSlices" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt index 029d04fee9b..2ec5bb46ed1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt @@ -4,10 +4,6 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - member { - name: "dtype" - mtype: "" - } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt index 920a48a2294..44a66874e70 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.RaggedTensor" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "dtype" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt index 261a6ae35a4..d71812ce83a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.SparseTensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt index fd80af51d2a..33742e3b867 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.Tensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" member { name: "OVERLOADABLE_OPERATORS" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 3cb50feac2d..d9414c31e7d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -46,7 +46,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 9e2fa7255fd..28efdb6e855 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 1bd43d28bc4..c9553efb58c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -47,7 +47,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 2e295c44b5f..16a878144ae 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index 91175909f77..d1d2db041e0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 09ed74d3460..18a6b8cbd1b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index c245d563e9e..0cf3d94ba68 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -48,7 +48,7 @@ tf_class { } member_method { name: "from_generator" - argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "from_tensor_slices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-constant.pbtxt index 502fda18fd8..175a5a9637c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-constant.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-constant.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.initializers.Constant" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt index 06beddc818b..0df534b7567 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-normal.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.initializers.GlorotNormal" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt index f37448f6346..15db2f10eb5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-glorot-uniform.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.initializers.GlorotUniform" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt new file mode 100644 index 00000000000..c23aa7827c3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.HeNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt new file mode 100644 index 00000000000..70412ed4252 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-he-uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.HeUniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt index fc6f16b04fa..b6f3b9fe45d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-identity.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.Identity" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-initializer.pbtxt index 03a69732c6c..8b032efce8b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-initializer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-initializer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.initializers.Initializer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt new file mode 100644 index 00000000000..a392394e113 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.LecunNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt new file mode 100644 index 00000000000..d863752d047 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-lecun-uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.LecunUniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-ones.pbtxt index c2ff715c649..ad86af8e06b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-ones.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-ones.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.Ones" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt index 303752f934f..c918524bf17 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-orthogonal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.Orthogonal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-normal.pbtxt index c1b1dfb4d23..aab2d0c5916 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.RandomNormal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-uniform.pbtxt index eac13a13246..3952a353150 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-random-uniform.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.RandomUniform" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt index b796faef0cf..53b9f2039e5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-truncated-normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.TruncatedNormal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt index ea46406d3a6..bb9a8470467 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-variance-scaling.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.VarianceScaling" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-zeros.pbtxt index 2596186705f..45a1535e052 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.-zeros.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.-zeros.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.Zeros" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt index f26775f4c84..1fa3c47ade8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.initializers.constant" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt index 0af16f3ae89..30e92a31858 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.initializers.glorot_normal" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt index d8272f2c64f..fc43bef7141 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.initializers.glorot_uniform" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt new file mode 100644 index 00000000000..0cade59f8de --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.he_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt new file mode 100644 index 00000000000..3b43fd2c763 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.he_uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.he_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt index c11593fe312..e857e75a84c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.identity" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt new file mode 100644 index 00000000000..8dfe4da3ea4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.lecun_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt new file mode 100644 index 00000000000..df8dfefc69c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.lecun_uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.initializers.lecun_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt index 436465c10a6..9c9024bd203 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.ones" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt index 6603f66736e..fa90188c798 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.orthogonal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt index bd61603c835..baed423d368 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt @@ -12,6 +12,14 @@ tf_module { name: "GlorotUniform" mtype: "" } + member { + name: "HeNormal" + mtype: "" + } + member { + name: "HeUniform" + mtype: "" + } member { name: "Identity" mtype: "" @@ -20,6 +28,14 @@ tf_module { name: "Initializer" mtype: "" } + member { + name: "LecunNormal" + mtype: "" + } + member { + name: "LecunUniform" + mtype: "" + } member { name: "Ones" mtype: "" @@ -60,10 +76,26 @@ tf_module { name: "glorot_uniform" mtype: "" } + member { + name: "he_normal" + mtype: "" + } + member { + name: "he_uniform" + mtype: "" + } member { name: "identity" mtype: "" } + member { + name: "lecun_normal" + mtype: "" + } + member { + name: "lecun_uniform" + mtype: "" + } member { name: "ones" mtype: "" @@ -84,6 +116,10 @@ tf_module { name: "truncated_normal" mtype: "" } + member { + name: "variance_scaling" + mtype: "" + } member { name: "zeros" mtype: "" @@ -96,22 +132,6 @@ tf_module { name: "get" argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "he_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "he_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "serialize" argspec: "args=[\'initializer\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt index a2a197b3541..3cc237e5c91 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.random_normal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt index 60621bfe2f0..ab6f955984e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.random_uniform" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt index 0504d7e2d48..9750914aeb0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.truncated_normal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt new file mode 100644 index 00000000000..5cff80eba00 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.initializers.variance_scaling" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt index 27774af69a2..dc75f07fca9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.initializers.zeros" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt index 71b5acc38fd..68135c36c92 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.initializers.Constant" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt index edff37e3a15..a9f559347bd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-normal.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.keras.initializers.GlorotNormal" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt index bc685ce0d58..255b1c14f13 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-glorot-uniform.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.keras.initializers.GlorotUniform" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt new file mode 100644 index 00000000000..5b53b41efd6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.HeNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt new file mode 100644 index 00000000000..41fd8a2e135 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-he-uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.HeUniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt index e0f0f3a93da..1a02232371b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.Identity" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt index ae5ea9e48c9..74597c8c7cf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.initializers.Initializer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt new file mode 100644 index 00000000000..6ef45b229ac --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.LecunNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt new file mode 100644 index 00000000000..d2e590a8855 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-lecun-uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.LecunUniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt index 57c0b0917d1..43dee054425 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.Ones" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt index b24844fa35c..e1d23edfd09 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.Orthogonal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt index 0753827aa67..8d165faa6c8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.RandomNormal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt index 280b0a0243d..a843a1e3cfc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.RandomUniform" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt index 4076aa595fe..14fe9547976 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.TruncatedNormal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt index a68219def66..c0e3d3585b8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.VarianceScaling" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt index 129fa18c617..1afb930ec53 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.Zeros" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt index b03cbb8eb80..7324655c00a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.initializers.constant" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt index 02f8c252bda..5bca6a37ee1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.keras.initializers.glorot_normal" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt index 6d18a3b6e7e..3a6cbe15e18 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt @@ -1,8 +1,10 @@ path: "tensorflow.keras.initializers.glorot_uniform" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt new file mode 100644 index 00000000000..5ece8aee902 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.he_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt new file mode 100644 index 00000000000..0d2dc7ed5b2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.he_uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.he_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt index dcdb6ddf5f0..647864a25fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.identity" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt new file mode 100644 index 00000000000..4eb04c91864 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_normal.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.lecun_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt new file mode 100644 index 00000000000..d1f8e8abc4c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.lecun_uniform.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.initializers.lecun_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt index cc2dd171dfc..ade249eedbe 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.ones" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt index 855065c1634..227f8957954 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.orthogonal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt index de4e798305c..f39b701806a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt @@ -12,6 +12,14 @@ tf_module { name: "GlorotUniform" mtype: "" } + member { + name: "HeNormal" + mtype: "" + } + member { + name: "HeUniform" + mtype: "" + } member { name: "Identity" mtype: "" @@ -20,6 +28,14 @@ tf_module { name: "Initializer" mtype: "" } + member { + name: "LecunNormal" + mtype: "" + } + member { + name: "LecunUniform" + mtype: "" + } member { name: "Ones" mtype: "" @@ -60,10 +76,26 @@ tf_module { name: "glorot_uniform" mtype: "" } + member { + name: "he_normal" + mtype: "" + } + member { + name: "he_uniform" + mtype: "" + } member { name: "identity" mtype: "" } + member { + name: "lecun_normal" + mtype: "" + } + member { + name: "lecun_uniform" + mtype: "" + } member { name: "ones" mtype: "" @@ -84,6 +116,10 @@ tf_module { name: "truncated_normal" mtype: "" } + member { + name: "variance_scaling" + mtype: "" + } member { name: "zeros" mtype: "" @@ -96,22 +132,6 @@ tf_module { name: "get" argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "he_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "he_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "lecun_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "serialize" argspec: "args=[\'initializer\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt index 55b70918542..75f754cc00c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.random_normal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt index 2cbe270b4ef..7541b5eddc3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.random_uniform" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt index f276d11d9a3..e5ebf905d08 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.truncated_normal" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt new file mode 100644 index 00000000000..4ec96caa16f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.variance_scaling.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.keras.initializers.variance_scaling" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt index f9b3359d7a9..de923f98977 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt @@ -1,7 +1,9 @@ path: "tensorflow.keras.initializers.zeros" tf_class { + is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 68e37cc8475..a2998f59114 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 115afd297ba..3f750a6200b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-k-l-divergence.pbtxt index 30d68f097be..b15ba6f2d6d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-k-l-divergence.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kullback_leibler_divergence\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kl_divergence\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-log-cosh.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-log-cosh.pbtxt index 9310f07f509..1bdc6751a4a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-log-cosh.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-log-cosh.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'logcosh\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'log_cosh\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt index a8a4134df5e..dc97f818309 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt @@ -116,6 +116,14 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "huber" + argspec: "args=[\'y_true\', \'y_pred\', \'delta\'], varargs=None, keywords=None, defaults=[\'1.0\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -124,6 +132,10 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "logcosh" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt index 1b4976294ed..17768aeafbe 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt @@ -200,6 +200,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-k-l-divergence.pbtxt index 21930e36fd9..d6f19bf3144 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.losses.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-k-l-divergence.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kullback_leibler_divergence\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'kl_divergence\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-log-cosh.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-log-cosh.pbtxt index 44d1f898717..0fea0e6712f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.losses.-log-cosh.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-log-cosh.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'logcosh\'], " + argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'log_cosh\'], " } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt index e681f29b99c..88a473e7372 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt @@ -116,6 +116,14 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "huber" + argspec: "args=[\'y_true\', \'y_pred\', \'delta\'], varargs=None, keywords=None, defaults=[\'1.0\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -124,6 +132,10 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "logcosh" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt index eb0fc81133a..b3c87d67d2b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt @@ -200,6 +200,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kl_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kld" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index a5200f86bfa..11e144bf13b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -754,7 +754,7 @@ tf_module { } member_method { name: "linspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'start\', \'stop\', \'num\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " } member_method { name: "load_library" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt index a563015afe3..a3ea216468e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.sparse.SparseTensor" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/tests/module_test.py b/tensorflow/tools/api/tests/module_test.py index 5397278f5f3..aa8a224d00b 100644 --- a/tensorflow/tools/api/tests/module_test.py +++ b/tensorflow/tools/api/tests/module_test.py @@ -79,6 +79,18 @@ class ModuleTest(test.TestCase): tf.compat.v1.summary.FileWriter # pylint: enable=pointless-statement + def testInternalKerasImport(self): + # pylint: disable=g-import-not-at-top + from tensorflow.python.keras import layers + normalization_parent = layers.Normalization.__module__.split('.')[-1] + if tf._major_api_version == 2: + self.assertEqual('normalization', normalization_parent) + self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR) + else: + self.assertEqual('normalization_v1', normalization_parent) + self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR) + # pylint: enable=g-import-not-at-top + if __name__ == '__main__': test.main() diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md index 988c6706c11..bf1993e8c7a 100644 --- a/tensorflow/tools/ci_build/README.md +++ b/tensorflow/tools/ci_build/README.md @@ -83,7 +83,7 @@ this UI, to see the logs for a failed build: * Submit special pull request (PR) comment to trigger CI: **bot:mlx:test** * Test session is run automatically. -* Test results and artefacts (log files) are reported via PR comments +* Test results and artifacts (log files) are reported via PR comments ##### CI Steps diff --git a/tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh b/tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh index 16e7fb9b1da..1893db7802c 100755 --- a/tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh +++ b/tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh @@ -15,4 +15,4 @@ #!/bin/bash set -x -DEFAULT_BAZEL_TARGETS="//tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... //tensorflow/compiler/mlir/lite/..." +DEFAULT_BAZEL_TARGETS="//tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... //tensorflow/compiler/mlir/lite/..." diff --git a/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh b/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh index a8be075e3cf..19f93a0d916 100755 --- a/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh @@ -16,6 +16,9 @@ set -e +# Retry on connection timeout. +bash -c "echo 'APT::Acquire::Retries \"3\";' > /etc/apt/apt.conf.d/80-retries" + # Install bootstrap dependencies from ubuntu deb repository. apt-get update apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh index 1498063630a..6a7e4c74576 100644 --- a/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh +++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh @@ -45,12 +45,14 @@ function run_build () { export ACTION_PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" export PYTHON_BIN_PATH="/usr/bin/python3" export TF2_BEHAVIOR=1 - tag_filters="gpu,-no_gpu,-nogpu,-benchmark-test,-no_oss,-oss_serial,-no_gpu_presubmit""$(maybe_skip_v1)" + # TODO(b/152356894): + # Remove -gpu_cupti once RBE supports cupti tests. + tag_filters="gpu,-no_gpu,-nogpu,-benchmark-test,-no_oss,-oss_serial,-no_gpu_presubmit,-gpu_cupti""$(maybe_skip_v1)" # Get the default test targets for bazel. source tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh - RBE_CONFIG="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.0-cudnn7-tensorrt5.1" + RBE_CONFIG="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0" TF_CUDA_CONFIG_REPO="${RBE_CONFIG}_config_cuda" TF_TENSORRT_CONFIG_REPO="${RBE_CONFIG}_config_tensorrt" TF_PYTHON_CONFIG_REPO="${RBE_CONFIG}_config_python" diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 1298479009b..f08ef720902 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -84,6 +84,13 @@ tensorflow::DoQuantizeTrainingOnSerializedGraphDef tensorflow::DeviceFactory::ListAllPhysicalDevices tensorflow::SessionState::kTensorHandleResourceTypeName +[server_lib] # server_lib +tensorflow::data::GrpcDataServer::Start +tensorflow::data::GrpcDataServer::Stop +tensorflow::data::GrpcDataServer::Target +tensorflow::data::NewMasterServer +tensorflow::data::NewWorkerServer + [protos_all] # device_lib, dtypes tensorflow::DataType_IsValid tensorflow::ConfigProto::ConfigProto diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 5d349da84bf..97aedfd9fa2 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -26,6 +26,7 @@ py_test( tags = [ "no_oss_py2", "no_pip", + "no_rocm", "no_windows", # numpy prints differently on windows. "noasan", "nomsan", diff --git a/tensorflow/tools/gcs_test/python/gcs_smoke.py b/tensorflow/tools/gcs_test/python/gcs_smoke.py index 087577e7a98..3c642ec13f6 100644 --- a/tensorflow/tools/gcs_test/python/gcs_smoke.py +++ b/tensorflow/tools/gcs_test/python/gcs_smoke.py @@ -44,7 +44,7 @@ def create_examples(num_examples, input_mean): examples = [] for row in range(num_examples): ex = example_pb2.Example() - ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0])) + ex.features.feature["id"].bytes_list.value.append(bytes(ids[row, 0])) ex.features.feature["target"].float_list.value.append(target[row, 0]) ex.features.feature["inputs"].float_list.value.append(inputs[row, 0]) examples.append(ex) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c38d7b84a74..991a5742579 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -95,6 +95,7 @@ COMMON_PIP_DEPS = [ "//tensorflow:tensorflow_py", "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_hdrs", "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs", + "//tensorflow/core/data/service/python:server_lib", "//tensorflow/core:protos_all_proto_srcs", "//tensorflow/examples/saved_model/integration_tests:mnist_util", "//tensorflow/lite/python/testdata:interpreter_test_data", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 37c11c1b0e1..888504256b1 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -96,7 +96,7 @@ if '--project_name' in sys.argv: if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 2.2.0a0, < 2.3.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly >= 2.3.0a0, < 2.4.0a0' elif 'tensorflow_estimator' in pkg and '2.0' in project_name: REQUIRED_PACKAGES[i] = 'tensorflow-estimator-2.0-preview' elif 'tensorflow_estimator' in pkg: diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index b90b9be67c7..142652f3e23 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -148,11 +148,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "XNNPACK", - sha256 = "77a4cea07169b4d67df456d50deffaa100e587192657c68ee4f2b7c12ba133d1", - strip_prefix = "XNNPACK-479e78c7f93a5764ffb221bdead3f290c7fd8ea3", + sha256 = "2afaaf5f866ec714358985b123c3115043b9e099638100937743997f02bbd8cb", + strip_prefix = "XNNPACK-05702cf4099ad019ad1abb8ba656bfe04304f32a", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/479e78c7f93a5764ffb221bdead3f290c7fd8ea3.zip", - "https://github.com/google/XNNPACK/archive/479e78c7f93a5764ffb221bdead3f290c7fd8ea3.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/05702cf4099ad019ad1abb8ba656bfe04304f32a.zip", + "https://github.com/google/XNNPACK/archive/05702cf4099ad019ad1abb8ba656bfe04304f32a.zip", ], ) @@ -201,11 +201,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "eb50646c27d32791d6b09b0422f29f52b8ff0385354abd117f68aa66da1e2e92", # SHARED_EIGEN_SHA - strip_prefix = "eigen-4da2c6b1974827b1999bab652a3d4703e1992d26", + sha256 = "2f046557f4093becf51b44c6339873c18e2f1ea55c4b3f3a08b7d15a1d9c6e5b", # SHARED_EIGEN_SHA + strip_prefix = "eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4da2c6b1974827b1999bab652a3d4703e1992d26/eigen-4da2c6b1974827b1999bab652a3d4703e1992d26.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/4da2c6b1974827b1999bab652a3d4703e1992d26/eigen-4da2c6b1974827b1999bab652a3d4703e1992d26.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz", ], ) @@ -589,8 +589,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "a711a3a46039154c38eade8bef1138b77fdb05ee" - LLVM_SHA256 = "b070be6653ac61e42649afcda0a02dee027cd610c1e2929663ca67fdb1301679" + LLVM_COMMIT = "7a4a98a9c4f39d9c395f5ce587dbbcb5450a9655" + LLVM_SHA256 = "d11b4b7e4522e86d9525f1ad1f840f2f871164ab0b0f848e9a1f314af63cf3d7" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel index b886e2f07b4..70cfa6c0359 100644 --- a/third_party/cpuinfo/BUILD.bazel +++ b/third_party/cpuinfo/BUILD.bazel @@ -16,8 +16,8 @@ C99OPTS = [ # Source code common to all platforms. COMMON_SRCS = [ "src/api.c", - "src/cache.c", "src/init.c", + "src/cache.c", ] # Architecture-specific sources and headers. @@ -59,10 +59,6 @@ EMSCRIPTEN_SRCS = [ "src/emscripten/init.c", ] -PNACL_SRCS = [ - "src/pnacl/init.c", -] - LINUX_X86_SRCS = [ "src/x86/linux/cpuinfo.c", "src/x86/linux/init.c", @@ -103,6 +99,7 @@ cc_library( ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS, ":linux_arm": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, ":linux_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, + ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS, ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, ":android_armv7": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS + ANDROID_ARM_SRCS, ":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS, @@ -117,6 +114,8 @@ cc_library( ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, + ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, ":emscripten_wasm": COMMON_SRCS + EMSCRIPTEN_SRCS, }), copts = C99OPTS + [ @@ -175,7 +174,11 @@ config_setting( config_setting( name = "linux_x86_64", values = {"cpu": "k8"}, - visibility = ["//visibility:public"], +) + +config_setting( + name = "linux_aarch64", + values = {"cpu": "aarch64"}, ) config_setting( diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 8a94afbfde1..f5ac7b39dfd 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -173,6 +173,13 @@ def InvokeHipcc(argv, log=False): out = ' -o ' + out_file[0] hipccopts = ' ' + # In hip-clang environment, we need to make sure that hip header is included + # before some standard math header like is included in any source. + # Otherwise, we get build error. + # Also we need to retain warning about uninitialised shared variable as + # warning only, even when -Werror option is specified. + if HIPCC_IS_HIPCLANG: + hipccopts += ' --include=hip/hip_runtime.h -Wno-error=cuda-shared-init ' hipccopts += ' ' + hipcc_compiler_options # Use -fno-gpu-rdc by default for early GPU kernel finalization # This flag would trigger GPU kernels be generated at compile time, instead diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 8fa64f264dc..203630802e4 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1035,18 +1035,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix - # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see - # https://github.com/bazelbuild/bazel/issues/760). - # However, this stops our custom clang toolchain from picking the provided - # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded - # toolchain. - # TODO: when bazel stops adding '-B/usr/bin' by default, remove this - # flag from the CROSSTOOL completely (see - # https://github.com/bazelbuild/bazel/issues/5634) - if should_download_clang: - cuda_defines["%{linker_bin_path}"] = "" - else: - cuda_defines["%{linker_bin_path}"] = host_compiler_prefix + cuda_defines["%{linker_bin_path}"] = "" cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 20ff2a4aafa..3c345e6724b 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -250,6 +250,10 @@ def _rocm_include_path(repository_ctx, rocm_config): inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include") + # Support hcc based off clang 11.0.0, included in ROCm3.1 + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/11.0.0/include/") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/11.0.0/include") + return inc_dirs def _enable_rocm(repository_ctx): diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 819aed65efe..f9b7b4bfa52 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -542,6 +542,7 @@ gentbl( td_file = "include/mlir/Dialect/Shape/IR/ShapeOps.td", td_srcs = [ ":StdOpsTdFiles", + "include/mlir/Interfaces/InferTypeOpInterface.td", ], ) @@ -560,6 +561,7 @@ cc_library( ":CallOpInterfaces", ":CommonFolders", ":IR", + ":InferTypeOpInterface", ":ShapeOpsIncGen", ":SideEffects", ":Support", @@ -846,6 +848,7 @@ cc_library( filegroup( name = "GPUOpsTdFiles", srcs = [ + "include/mlir/Dialect/GPU/GPUBase.td", "include/mlir/Dialect/GPU/GPUOps.td", "include/mlir/Dialect/LLVMIR/LLVMOpBase.td", "include/mlir/Interfaces/SideEffects.td", @@ -853,6 +856,35 @@ filegroup( ], ) +gentbl( + name = "ParallelLoopMapperAttrGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-struct-attr-decls", + "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.h.inc", + ), + ( + "-gen-struct-attr-defs", + "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.cpp.inc", + ), + ( + "-gen-enum-decls", + "include/mlir/Dialect/GPU/ParallelLoopMapperEnums.h.inc", + ), + ( + "-gen-enum-defs", + "include/mlir/Dialect/GPU/ParallelLoopMapperEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td", + td_srcs = [ + ":GPUOpsTdFiles", + ":AffineOpsTdFiles", + ], +) + gentbl( name = "GPUOpsIncGen", strip_include_prefix = "include", @@ -886,7 +918,7 @@ cc_library( ], ), hdrs = glob([ - "include/mlir/Dialect/GPU/*.h", + "include/mlir/Dialect/GPU/GPUDialect.h", ]), includes = ["include"], deps = [ @@ -908,6 +940,8 @@ cc_library( ], ), hdrs = [ + "include/mlir/Dialect/GPU/MemoryPromotion.h", + "include/mlir/Dialect/GPU/ParallelLoopMapper.h", "include/mlir/Dialect/GPU/Passes.h", "include/mlir/Dialect/GPU/Utils.h", ], @@ -917,10 +951,12 @@ cc_library( ":GPUDialect", ":IR", ":LoopOps", + ":ParallelLoopMapperAttrGen", ":Pass", ":StandardOps", ":Support", ":Transforms", + "@llvm-project//llvm:support", ], ) @@ -982,6 +1018,7 @@ cc_library( ":GPUCommonTransforms", ":GPUDialect", ":GPUToNVVMGen", + ":GPUTransforms", ":IR", ":LLVMTransforms", ":NVVMDialect", @@ -1733,17 +1770,14 @@ cc_library( cc_library( name = "LoopsToGPU", - srcs = [ - "lib/Conversion/LoopsToGPU/LoopsToGPU.cpp", - ], - hdrs = [ - "include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h", - ], + srcs = ["lib/Conversion/LoopsToGPU/LoopsToGPU.cpp"], + hdrs = ["include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h"], includes = ["include"], deps = [ ":Affine", ":AffineToStandardTransforms", ":GPUDialect", + ":GPUTransforms", ":IR", ":LoopOps", ":Pass", diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index 5d569827860..476fd8b77df 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -48,35 +48,35 @@ gentbl( gentbl( name = "TestOpsIncGen", - strip_include_prefix = "lib/TestDialect", + strip_include_prefix = "lib/Dialect/Test", tbl_outs = [ ( "-gen-op-decls", - "lib/TestDialect/TestOps.h.inc", + "lib/Dialect/Test/TestOps.h.inc", ), ( "-gen-op-defs", - "lib/TestDialect/TestOps.cpp.inc", + "lib/Dialect/Test/TestOps.cpp.inc", ), ( "-gen-dialect-decls", - "lib/TestDialect/TestOpsDialect.h.inc", + "lib/Dialect/Test/TestOpsDialect.h.inc", ), ( "-gen-enum-decls", - "lib/TestDialect/TestOpEnums.h.inc", + "lib/Dialect/Test/TestOpEnums.h.inc", ), ( "-gen-enum-defs", - "lib/TestDialect/TestOpEnums.cpp.inc", + "lib/Dialect/Test/TestOpEnums.cpp.inc", ), ( "-gen-rewriters", - "lib/TestDialect/TestPatterns.inc", + "lib/Dialect/Test/TestPatterns.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/TestDialect/TestOps.td", + td_file = "lib/Dialect/Test/TestOps.td", td_srcs = [ "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td", @@ -91,15 +91,15 @@ gentbl( cc_library( name = "TestDialect", srcs = [ - "lib/TestDialect/TestDialect.cpp", - "lib/TestDialect/TestPatterns.cpp", + "lib/Dialect/Test/TestDialect.cpp", + "lib/Dialect/Test/TestPatterns.cpp", ], hdrs = [ - "lib/TestDialect/TestDialect.h", + "lib/Dialect/Test/TestDialect.h", ], includes = [ "lib/DeclarativeTransforms", - "lib/TestDialect", + "lib/Dialect/Test", ], deps = [ ":TestOpsIncGen", @@ -154,7 +154,7 @@ cc_library( "lib/Transforms/*.cpp", ]), defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"], - includes = ["lib/TestDialect"], + includes = ["lib/Dialect/Test"], deps = [ ":TestDialect", ":TestLinalgTransformPatternsIncGen", diff --git a/third_party/pthreadpool/BUILD.bazel b/third_party/pthreadpool/BUILD.bazel index 1267e4f3736..59bf52d8cbd 100644 --- a/third_party/pthreadpool/BUILD.bazel +++ b/third_party/pthreadpool/BUILD.bazel @@ -10,6 +10,7 @@ exports_files(["LICENSE"]) cc_library( name = "pthreadpool", srcs = [ + "src/threadpool-atomics.h", "src/threadpool-pthreads.c", "src/threadpool-utils.h", ], @@ -18,7 +19,15 @@ cc_library( ], copts = [ "-O2", - ], + ] + select({ + ":optimized_build": ["-O2"], + "//conditions:default": [], + }) + select({ + ":linux_aarch64": ["-DPTHREADPOOL_USE_CPUINFO=1"], + ":android_arm64": ["-DPTHREADPOOL_USE_CPUINFO=1"], + ":android_armv7": ["-DPTHREADPOOL_USE_CPUINFO=1"], + "//conditions:default": ["-DPTHREADPOOL_USE_CPUINFO=0"], + }), defines = [ "PTHREADPOOL_NO_DEPRECATED_API", ], @@ -28,5 +37,40 @@ cc_library( strip_include_prefix = "include", deps = [ "@FXdiv", - ], + ] + select({ + ":linux_aarch64": ["@cpuinfo"], + ":android_arm64": ["@cpuinfo"], + ":android_armv7": ["@cpuinfo"], + "//conditions:default": [], + }), +) + +############################# Build configurations ############################# + +config_setting( + name = "optimized_build", + values = { + "compilation_mode": "opt", + }, +) + +config_setting( + name = "linux_aarch64", + values = {"cpu": "aarch64"}, +) + +config_setting( + name = "android_armv7", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi-v7a", + }, +) + +config_setting( + name = "android_arm64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "arm64-v8a", + }, ) diff --git a/third_party/pthreadpool/workspace.bzl b/third_party/pthreadpool/workspace.bzl index 63eeac2c5a6..d13e7803408 100644 --- a/third_party/pthreadpool/workspace.bzl +++ b/third_party/pthreadpool/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "pthreadpool", - strip_prefix = "pthreadpool-ebd50d0cfa3664d454ffdf246fcd228c3b370a11", - sha256 = "ca4fc774cf2339cb739bba827de8ed4ccbd450c4608e05329e974153448aaf56", + strip_prefix = "pthreadpool-76042155a8b1e189c8f141429fd72219472c32e1", + sha256 = "91c7b00c16c60c96f23d1966d524879c0f6044caf4bc5e9fc06518dda643e07e", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/ebd50d0cfa3664d454ffdf246fcd228c3b370a11.tar.gz", - "https://github.com/Maratyszcza/pthreadpool/archive/ebd50d0cfa3664d454ffdf246fcd228c3b370a11.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/76042155a8b1e189c8f141429fd72219472c32e1.tar.gz", + "https://github.com/Maratyszcza/pthreadpool/archive/76042155a8b1e189c8f141429fd72219472c32e1.tar.gz", ], build_file = "//third_party/pthreadpool:BUILD.bazel", )