diff --git a/.github/bot_config.yml b/.github/bot_config.yml index d0e7256aec0..952ff91fef7 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -40,6 +40,22 @@ segfault_memory: # assignees filesystem_security_assignee: - mihaimaruseac + +tflite_micro_path: + - tensorflow/lite/micro + +tflite_micro_comment: > + Thanks for contributing to TensorFlow Lite Micro. + + + To keep this process moving along, we'd like to make sure that you have completed the items on this list: + * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md) + * Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) + * Linked to the issue from the PR description + + + We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review. + # Cuda Comment cuda_comment: > From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: diff --git a/RELEASE.md b/RELEASE.md index 7057657c340..6890352cf8a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -37,6 +37,9 @@ * XLA:CPU and XLA:GPU devices are no longer registered by default. Use `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be removed). +* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type + `tf.complex64` or `tf.complex128`, because the behavior of these ops is not + well defined for complex types. ## Known Caveats @@ -120,6 +123,13 @@ customization of how gradients are aggregated across devices, as well as `gradients_transformers` to allow for custom gradient transformations (such as gradient clipping). + * The `steps_per_execution` argument in `compile()` is no longer + experimental; if you were passing `experimental_steps_per_execution`, + rename it to `steps_per_execution` in your code. This argument controls + the number of batches to run during each `tf.function` call when calling + `fit()`. Running multiple batches inside a single `tf.function` call can + greatly improve performance on TPUs or small models with a large Python + overhead. * `tf.function` / AutoGraph: * Added `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing @@ -147,6 +157,8 @@ * Deprecate `Interpreter::UseNNAPI(bool)` C++ API * Prefer using `NnApiDelegate()` and related delegate configuration methods directly. * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair. + * TFLite Profiler for Android is available. See the detailed + [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android). * * `tf.random`: * diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 9d8032aca52..0cace4d102f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -387,6 +387,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/platform", + "//tensorflow/core/platform:blocking_counter", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index b4297033b6d..81fb9d1a2b8 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/net.h" @@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, collective_executor_handle->get()->StartAbort(status->status); } +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + auto collective_executor_handle = context->GetCollectiveExecutorHandle(); + tensorflow::Notification done; + collective_executor_handle->get()->remote_access()->CheckPeerHealth( + task, [&done, status](const Status& s) { + status->status = s; + done.Notify(); + }); + done.WaitForNotification(); +} + TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; result->num_items = num_items; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index ebd14b4b571..a08d4f29fcc 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -238,6 +238,13 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, TF_Status* status); +// Checks the health of collective ops peers. Explicit health check is needed in +// multi worker collective ops to detect failures in the cluster. If a peer is +// down, collective ops may hang. +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, + const char* task, + TF_Status* status); + // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { // Number of dimensions. -1 indicates unknown rank. diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 3fff9bcd371..ec8cfe4a31a 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } -// This test only works when the TF build includes XLA compiler. One way to set -// this up is via bazel build option "--define with_xla_support=true". -// -// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to -// something like TENSORFLOW_CAPI_USE_XLA. -#ifdef TENSORFLOW_EAGER_USE_XLA -TEST_F(CApiFunctionTest, StatelessIf_XLA) { - TF_Function* func; - const std::string funcName = "BranchFunc"; - DefineFunction(funcName.c_str(), &func); - TF_GraphCopyFunction(host_graph_, func, nullptr, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* feed = Placeholder(host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_OperationDescription* desc = - TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); - TF_AddInput(desc, {true_cond, 0}); - TF_Output inputs[] = {{feed, 0}}; - TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_SetAttrType(desc, "Tcond", TF_BOOL); - TF_DataType inputType = TF_INT32; - TF_SetAttrTypeList(desc, "Tin", &inputType, 1); - TF_SetAttrTypeList(desc, "Tout", &inputType, 1); - TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); - TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); - TF_SetDevice(desc, "/device:XLA_CPU:0"); - auto op = TF_FinishOperation(desc, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - ASSERT_NE(op, nullptr); - - // Create a session for this graph. - CSession csession(host_graph_, s_, /*use_XLA*/ true); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - // Run the graph. - csession.SetInputs({{feed, Int32Tensor(17)}}); - csession.SetOutputs({op}); - csession.Run(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Tensor* out = csession.output_tensor(0); - ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* output_contents = static_cast(TF_TensorData(out)); - EXPECT_EQ(-17, *output_contents); - - // Clean up - csession.CloseAndDelete(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_DeleteFunction(func); -} -#endif // TENSORFLOW_EAGER_USE_XLA - } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index cc02d83fe01..1a3b348e8f9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -6,7 +6,6 @@ load( "tf_copts", "tf_cuda_cc_test", "tf_cuda_library", - "tfe_xla_copts", ) load( "//tensorflow/core/platform:build_config.bzl", @@ -31,7 +30,7 @@ tf_cuda_library( "c_api_unified_experimental.h", ], hdrs = ["c_api.h"], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -72,13 +71,6 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:xla_device", - ], - "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", @@ -228,7 +220,6 @@ tf_cuda_cc_test( "gradients_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -278,6 +269,7 @@ cc_library( "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:span", ], @@ -290,12 +282,9 @@ tf_cuda_cc_test( "mnist_gradients_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "nomac", - "notap", # TODO(b/166150182): Enable - "no_oss", # TODO(b/166150182): Enable ], deps = [ ":abstract_tensor_handle", @@ -553,7 +542,6 @@ tf_cuda_cc_test( "c_api_debug_test.cc", "c_api_test.cc", ], - extra_copts = tfe_xla_copts(), tags = [ "noguitar", # TODO(b/155445984): flaky #"guitar", @@ -608,7 +596,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -641,7 +628,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -660,7 +646,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", "noasan", # leaks gRPC server instances @@ -694,7 +679,6 @@ tf_cuda_cc_test( ], # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), tags = [ "no_windows", ], @@ -729,7 +713,7 @@ tf_cuda_library( "c_api_experimental.h", "c_api_unified_experimental.h", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -801,7 +785,6 @@ tf_cuda_cc_test( "c_api_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ @@ -825,7 +808,6 @@ tf_cuda_cc_test( "c_api_unified_experimental_test.cc", ], args = ["--heap_check=local"], - extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 2e4d230f5ee..2920c94b4c2 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -51,9 +51,6 @@ limitations under the License. #include "tensorflow/core/protobuf/device_filters.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#endif // TENSORFLOW_EAGER_USE_XLA #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -1148,26 +1145,23 @@ void TFE_DeleteOp(TFE_Op* op) { tensorflow::unwrap(op)->Release(); } +const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->Name().c_str(); +} + +TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { + return tensorflow::wrap( + &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); +} + void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); } -const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { +const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) { return tensorflow::unwrap(op)->DeviceName().c_str(); } -void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { -#ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable); - if (!s.ok()) { - LOG(ERROR) << "Could not enable XLA compilation for op: " << s; - } -#else - LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " - "built with XLA support."; -#endif // TENSORFLOW_EAGER_USE_XLA -} - void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); } @@ -1180,6 +1174,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, static_cast(num_inputs)}); } +extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) { + return tensorflow::unwrap(op)->GetInputs().size(); +} + +extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index, + TF_Status* status) { + return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]); +} + TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret = TF_ATTR_INT; @@ -1485,7 +1488,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) { tensorflow::unwrap(ctx)->EndStep(); } -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { +const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) { return tensorflow::wrap( &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); } @@ -1611,19 +1614,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status Execute(tensorflow::EagerOperation* op, + tensorflow::Status Execute(const tensorflow::EagerOperation* op, tensorflow::TensorHandle** retvals, int* num_retvals) override { - std::vector inputs; - inputs.reserve(op->Inputs().size()); - for (int i = 0; i < op->Inputs().size(); ++i) { - op->Inputs()[i]->Ref(); - inputs.push_back(tensorflow::wrap(op->Inputs()[i])); - } std::vector outputs(*num_retvals); TF_Status status; - device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { @@ -1633,10 +1629,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_DeleteTensorHandle(outputs[i]); } } - - for (auto inp : inputs) { - TFE_DeleteTensorHandle(inp); - } return status.status; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 5afe3047dd7..a58c681e8fe 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); - TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); +// Returns the op or function name `op` will execute. +// +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op, + TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); // The returned string remains valid throughout the lifetime of 'op'. -TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status); -// When 'enable' is set to 1, and if TensorFlow library is built with XLA -// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. -// -// If the library is not built with XLA support, this call would be a no-op. -TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, - unsigned char enable); - TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status); @@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op, int num_inputs, TF_Status* status); +// Fetches the current number of inputs attached to `op`. +// +// Does not use the operation's definition to determine how many inputs should +// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an +// already-finalized operation. +// +// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat +// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a +// particular named input list, which may only be part of the op's inputs). +TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op, + TF_Status* status); +// Returns a borrowed reference to one of `op`'s inputs. Use +// `TFE_TensorHandleCopySharingTensor` to make a new reference. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, + int index, + TF_Status* status); + TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index dd55f05283b..b5721cdab0a 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -22,9 +22,6 @@ limitations under the License. #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/platform/status.h" -#ifdef TENSORFLOW_EAGER_USE_XLA -#include "tensorflow/compiler/jit/xla_device.h" -#endif // TENSORFLOW_EAGER_USE_XLA using tensorflow::string; @@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } -#ifdef TENSORFLOW_EAGER_USE_XLA - auto* device = absl::get(handle->device()); - - // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. - auto* xla_device = dynamic_cast(device); - if (xla_device != nullptr) { - tensorflow::XlaDevice::PaddedShapeFn shape_fn = - xla_device->metadata().padded_shape_fn(); - xla::Shape padded_shape; - status->status = shape_fn(*tensor, &padded_shape); - if (!status->status.ok()) { - return nullptr; - } - if (VLOG_IS_ON(3)) { - std::vector shape_to_log = - TensorShapeAsVector(*handle, &status->status); - if (!status->status.ok()) { - // Ignore the status here as we are simply logging. - status->status = tensorflow::Status::OK(); - } else { - VLOG(3) << "Fully padded shape of [" - << absl::StrJoin(shape_to_log, ", ") << "] is " - << padded_shape.DebugString(); - } - } - - if (padded_shape.IsTuple()) { - if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { - // Currently, the only case of XlaTensor containing a tuple shape is to - // represent 64 bit ints, doubles, and complex numbers (we don't support - // 64bit complex numbers). - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should only contain tuples of size 2. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // shape0 is not a const& because we will assign it to padded_shape below. - // It is illegal to assign a part of a message to itself. - xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); - const xla::Shape& shape1 = - xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (shape0.IsTuple() || shape1.IsTuple()) { - status->status = tensorflow::errors::InvalidArgument( - "XlaTensors should not contain nested tuples. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - if (!xla::ShapeUtil::Equal(shape0, shape1)) { - status->status = tensorflow::errors::InvalidArgument( - "Subshapes of XlaTensors should be the same. Shape: ", - padded_shape.DebugString()); - return nullptr; - } - - // Since the only case we handle here are two equal subshapes, we - // simply return one of them. The caller will interpret it as this - // shape directly storing the 64bit types. This approximation is good - // enough for this API's debugging use case. - padded_shape = shape0; - } - - int rank = padded_shape.dimensions_size(); - std::vector dev_dims; - dev_dims.reserve(rank); - if (rank == 1) { - // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, - dev_dims.push_back(padded_shape.dimensions(0)); - } else { - for (int i = rank - 1; i >= 0; --i) { - tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i); - dev_dims.push_back(padded_shape.dimensions(dim_index)); - } - } - status->status = tensorflow::Status::OK(); - return new TFE_TensorDebugInfo(dev_dims); - } -#endif // TENSORFLOW_EAGER_USE_XLA - - // If the tensor is not an XLA tensor, the device shape is - // the same as regular tensor shape. std::vector dev_dims = TensorShapeAsVector(*handle, &status->status); if (!status->status.ok()) { diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 37f868468e4..12546c6082a 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -414,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs; // Fetch a reference to `op`'s attributes. The returned reference is only valid // while `op` is alive. -const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op); +TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); // Add attributes in `attrs` to `op`. // // Does not overwrite or update existing attributes, but adds new ones. @@ -435,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, size_t proto_len, TF_Status* status); -#define TFE_CUSTOM_DEVICE_VERSION 2 +// TODO(b/166642410): It would be nice, for custom devices and for other users, +// to have a non-string representation of devices (TF_Device) extracted from +// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. + +#define TFE_CUSTOM_DEVICE_VERSION 3 // Struct to be filled in typedef struct TFE_CustomDevice { @@ -454,9 +458,16 @@ typedef struct TFE_CustomDevice { void* device_info); // Method to execute an operation. - void (*execute)(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, + // + // Arguments provide enough information to reconstruct the original `TFE_Op`, + // or construct a transformed version, by inspecting the passed `op`. + // + // TFE_OpGetDevice(op) records the original placement of the operation. It may + // be an empty string if no device was explicitly requested, but will + // otherwise be the name of this custom device. Ops are placed onto a custom + // device if any of their inputs are on that custom device, but custom devices + // are free to set a bad status in order to require explicit placement. + void (*execute)(const TFE_Op* op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info); // Method to delete a device. diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index a4d31417073..4975d303375 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Function_ident_XLA_CPU) { - // First create a simple identity function. - TF_Graph* function_graph = TF_NewGraph(); - TF_OperationDescription* arg_descr = - TF_NewOperation(function_graph, "Placeholder", "arg"); - TF_SetAttrType(arg_descr, "dtype", TF_INT32); - TF_Status* status = TF_NewStatus(); - TF_Operation* arg = TF_FinishOperation(arg_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_OperationDescription* id_descr = - TF_NewOperation(function_graph, "Identity", "id"); - TF_SetAttrType(id_descr, "T", TF_INT32); - TF_AddInput(id_descr, {arg, 0}); - TF_Operation* id = TF_FinishOperation(id_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_Output input{arg, 0}; - TF_Output output{id, 0}; - TF_Function* fn = - TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, - &output, nullptr, nullptr, "test", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteGraph(function_graph); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - TFE_ContextAddFunction(ctx, fn, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteFunction(fn); - - for (bool async : {false, true, false}) { - TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); - TFE_Executor* executor = TFE_NewExecutor(async); - TFE_ContextSetExecutorForThread(ctx, executor); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); - - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - - // Now run it via XLA. - TFE_OpSetXLACompilation(op, true); - - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); - - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_ContextSetExecutorForThread(ctx, old_executor); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - TFE_DeleteExecutor(old_executor); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); - } - TFE_ContextRemoveFunction(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContext(ctx); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteStatus(status); -} -#endif // TENSORFLOW_EAGER_USE_XLA - void Executor_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 724176505ba..37bb9c5f16b 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -876,89 +876,6 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteStatus(status); } -#ifdef TENSORFLOW_EAGER_USE_XLA -void Execute_MatMul_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* matmul = MatMulOp(ctx, m, m); - - TFE_OpSetXLACompilation(matmul, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - // Running a primitive TF operator via XLA is not yet supported. - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_DeleteOp(matmul); - TFE_DeleteTensorHandle(m); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - EXPECT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float product[4] = {0}; - EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); - memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(7, product[0]); - EXPECT_EQ(10, product[1]); - EXPECT_EQ(15, product[2]); - EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); -} -TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } -TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } - -void Execute_Min_XLA_CPU(bool async) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); - TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); - TFE_Op* minOp = MinOp(ctx, input, axis); - - TFE_OpSetXLACompilation(minOp, true); - - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; - TFE_Execute(minOp, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteOp(minOp); - TFE_DeleteTensorHandle(input); - TFE_DeleteTensorHandle(axis); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); - - TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - float output[2] = {0}; - EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); - memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - EXPECT_EQ(1, output[0]); - EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); -} -TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } -TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } -#endif // TENSORFLOW_EAGER_USE_XLA - void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1620,4 +1537,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { TFE_DeleteContext(ctx); } +// Needs to work with a const TFE_Op since custom devices should not modify the +// op they are called with. +TFE_Op* CloneOp(const TFE_Op* other) { + TF_Status* status = TF_NewStatus(); + TFE_Context* context = TFE_OpGetContext(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* op_name = TFE_OpGetName(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* ret = TFE_NewOp(context, op_name, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char* device = TFE_OpGetDevice(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(ret, device, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other)); + int num_inputs = TFE_OpGetFlatInputCount(other, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + for (int input_index = 0; input_index < num_inputs; ++input_index) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(ret, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + } + TF_DeleteStatus(status); + return ret; +} + +TEST(CAPI, TestTFE_OpRecreation) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Clone an op with attributes and a device set. + TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64); + TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetDevice(original_var_op, + "/job:localhost/replica:0/task:0/device:CPU:0", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned = CloneOp(original_var_op); + + EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0", + std::string(TFE_OpGetDevice(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status))); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + int num_retvals = 1; + TFE_TensorHandle* ret; + TFE_Execute(cloned, &ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(ret); + + // Clone an op with inputs and no device set. + TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); + TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInputList(original_identity, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* cloned_identity = CloneOp(original_identity); + EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status))); + TFE_TensorHandle* identity_ret[] = {nullptr, nullptr}; + num_retvals = 2; + TFE_Execute(cloned_identity, identity_ret, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(identity_ret[0]); + TFE_DeleteTensorHandle(identity_ret[1]); + + TFE_DeleteOp(cloned_identity); + TFE_DeleteOp(original_identity); + TFE_DeleteOp(original_var_op); + TFE_DeleteOp(cloned); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + } // namespace diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 1c078d4f42c..b058c79a17b 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context, name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context); ASSERT_FALSE(arrived); @@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) { bool executed = false; const char* custom_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom_device_name, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. @@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. @@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1"; bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), custom0, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom0, + /*strict_scope_placement=*/false, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), custom1, &arrived, &executed, + RegisterLoggingDevice(context.get(), custom1, + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) { ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1)); - // Custom device: mix of custom/physical fails. + // Custom device: mix of custom/physical places the op on the custom device. matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); num_retvals = 1; + executed = false; TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); - ASSERT_NE(TF_OK, TF_GetCode(status.get())); - ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); - ASSERT_TRUE( - absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull + EXPECT_TRUE(executed); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteTensorHandle(retval); + + // Explicit placement still forces the op onto the requested device + matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); + TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); + + // Custom devices can refuse to do type-based dispatch (as hcustom1 is + // configured to do) + matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get())); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + EXPECT_FALSE(executed); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); } TEST(CUSTOM_DEVICE, InvalidRegistrationError) { @@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); bool arrived = false; bool executed = false; - RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed, + RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) << TF_Message(status.get()); const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, + &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) - << TF_Message(status.get()); - - RegisterLoggingDevice(context.get(), - "/job:localhost/replica:0/task:0/device:CPU:0", + RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) << TF_Message(status.get()); + + RegisterLoggingDevice( + context.get(), "/job:localhost/replica:0/task:0/device:CPU:0", + /*strict_scope_placement=*/true, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) + << TF_Message(status.get()); } diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc index 28de3665653..014abe38368 100644 --- a/tensorflow/c/eager/custom_device_testutil.cc +++ b/tensorflow/c/eager/custom_device_testutil.cc @@ -33,6 +33,9 @@ struct LoggingDevice { bool* arrived_flag; // Set to true whenever an operation is executed bool* executed_flag; + // If true, only explicit op placements are accepted. If false, uses + // type-based dispatch. + bool strict_scope_placement; }; struct LoggedTensor { @@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, return nullptr; } -void LoggingDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + LoggingDevice* dev = reinterpret_cast(device_info); + if (dev->strict_scope_placement && *requested_placement == '\0') { + TF_SetStatus(s, TF_INTERNAL, + "Ops must be placed on the device explicitly, or their inputs " + "first copied to other devices."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, s); + if (TF_GetCode(s) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + TFE_Op* op(TFE_NewOp(context, operation_name, s)); if (TF_GetCode(s) != TF_OK) return; TFE_OpAddAttrs(op, attributes); TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); + if (TF_GetCode(s) != TF_OK) return; + int num_inputs = TFE_OpGetFlatInputCount(original_op, s); + if (TF_GetCode(s) != TF_OK) return; for (int j = 0; j < num_inputs; ++j) { - TFE_TensorHandle* input = inputs[j]; + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s); + if (TF_GetCode(s) != TF_OK) return; const char* input_device = TFE_TensorHandleDeviceName(input, s); if (TF_GetCode(s) != TF_OK) return; if (dev->device_name == input_device) { @@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) { } // namespace void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status) { + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status) { TFE_CustomDevice custom_device; custom_device.copy_tensor_to_device = &CopyToLoggingDevice; custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; @@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name, device->executed_flag = executed_flag; device->device_name = name; device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + device->strict_scope_placement = strict_scope_placement; TFE_RegisterCustomDevice(context, custom_device, name, device, status); } @@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag, logging_device->device_name = name; logging_device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + logging_device->strict_scope_placement = true; *device_info = reinterpret_cast(logging_device); } diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h index 509df7d3e3e..a7c60080adf 100644 --- a/tensorflow/c/eager/custom_device_testutil.h +++ b/tensorflow/c/eager/custom_device_testutil.h @@ -25,8 +25,8 @@ limitations under the License. #include "tensorflow/c/tf_status.h" void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status); + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status); void AllocateLoggingDevice(const char* name, bool* arrived_flag, bool* executed_flag, TFE_CustomDevice** device, void** device_info); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 9bcd0d0fea0..89ff140fa73 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -242,6 +242,7 @@ namespace internal { Status Reset(AbstractOperation* op_, const char* op, const char* raw_device_name, ForwardOperation* forward_op_) { forward_op_->op_name = op; + forward_op_->attrs.Reset(op); return op_->Reset(op, raw_device_name); } Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, @@ -418,6 +419,11 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. forward_op_->outputs.push_back(retvals[i]); } + // TODO(b/166669239): This is needed to support AttrBuilder::Get for string + // attributes. Number type attrs and DataType attrs work fine without this. + // Consider getting rid of this and making the behavior between number types + // and string consistent. + forward_op_->attrs.BuildNodeDef(); std::vector tape_tensors; for (auto t : retvals) { tape_tensors.push_back(TapeTensor(t, ctx)); diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 80b1f157074..56f0b847002 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -507,6 +507,57 @@ TEST_P(CppGradients, TestIdentityNGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestSetAttrString) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr t; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + t.reset(x_raw); + } + + AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx.get(); + Status s = Reset(check_numerics_op.get(), "CheckNumerics", + /*raw_device_name=*/nullptr, &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + if (isa(check_numerics_op.get())) { + s = dyn_cast(check_numerics_op.get()) + ->SetOpName("check_numerics"); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + } + s = AddInput(check_numerics_op.get(), t.get(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + string message = "This is the way!"; + s = SetAttrString(check_numerics_op.get(), "message", message.data(), + message.length(), &forward_op); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + int num_retvals = 1; + std::vector outputs(1); + GradientRegistry registry; + std::unique_ptr tape(new Tape(/*persistent=*/false)); + s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), + &num_retvals, &forward_op, tape.get(), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + string read_message; + s = forward_op.attrs.Get("message", &read_message); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(read_message, message); +} + // TODO(b/164171226): Enable this test with tfrt after AddInputList is // supported. It is needed for IdentityN. #ifdef PLATFORM_GOOGLE diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index ee212b21a96..7b68ec2c9f4 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation { virtual Status InputLength(const char* input_name, int* length) = 0; virtual Status OutputLength(const char* output_name, int* length) = 0; - // Experimental - virtual Status SetUseXla(bool enable) = 0; - // Set stack trace to be used for potential async error reporting. virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index d6dd94806a7..1f04e25820e 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -765,13 +765,13 @@ TEST_P(CppGradients, TestMNIST_Training) { #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), + ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef"), + ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #endif diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 4b2c87c678d..9f5d0d149d4 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -31,11 +31,15 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -using std::vector; -using tracing::TracingOperation; - // ========================== Tape Ops ============================== +namespace tensorflow { +namespace gradients { +namespace internal { + +using std::vector; +using tensorflow::tracing::TracingOperation; + // Computes `inputs[0] + inputs[1]` and records it on the tape. Status Add(AbstractContext* ctx, Tape* tape, absl::Span inputs, @@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx, AbstractTensorHandle* scores = temp_outputs[0]; + temp_outputs.resize(2); TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), "softmax_loss", registry)); // Compute Softmax(Scores,labels) @@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { TFE_DeleteContextOptions(opts); return Status::OK(); } + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index b6de8ff6788..efe196e9ba3 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -27,13 +27,13 @@ limitations under the License. #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" - -using namespace tensorflow; -using namespace tensorflow::gradients; -using namespace tensorflow::gradients::internal; +#include "tensorflow/core/platform/status.h" // ========================== Tape Ops ============================== +namespace tensorflow { +namespace gradients { +namespace internal { // Computes `inputs[0] + inputs[1]` and records it on the tape. Status Add(AbstractContext* ctx, Tape* tape, absl::Span inputs, @@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx, const GradientRegistry& registry); Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index d0e9f351478..b9d7be7f8ea 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -255,28 +255,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context, // Since this function is used to satisfy the TFE_CustomDevice C API, // device_info is passed in using a C-style generic. It must always be a // ParallelDevice. -void ParallelDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, - const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, +void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status, void* device_info) { + const char* requested_placement = TFE_OpGetDevice(original_op, status); + if (*requested_placement == '\0') { + TF_SetStatus( + status, TF_INTERNAL, + "Ops must be placed on the parallel device explicitly, or their inputs " + "first un-packed. Got an un-placed op with an input placed on the " + "parallel device."); + return; + } + TFE_Context* context = TFE_OpGetContext(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const char* operation_name = TFE_OpGetName(original_op, status); + if (TF_GetCode(status) != TF_OK) return; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); + NamedParallelDevice* named_device = reinterpret_cast(device_info); std::vector typed_inputs; + int num_inputs = TFE_OpGetFlatInputCount(original_op, status); + if (TF_GetCode(status) != TF_OK) return; typed_inputs.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { + TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status); + if (TF_GetCode(status) != TF_OK) return; const char* tensor_handle_device = - TFE_TensorHandleDeviceName(inputs[i], status); + TFE_TensorHandleDeviceName(input, status); if (TF_GetCode(status) != TF_OK) return; if (named_device->name() == tensor_handle_device) { // We assume that any tensors already placed on this device are // ParallelTensors. typed_inputs.emplace_back(reinterpret_cast( - TFE_TensorHandleDevicePointer(inputs[i], status))); + TFE_TensorHandleDevicePointer(input, status))); if (TF_GetCode(status) != TF_OK) return; } else { - typed_inputs.emplace_back(inputs[i]); + typed_inputs.emplace_back(input); } } diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 68875d61e47..3965359d0f6 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -29,6 +29,7 @@ cc_library( ":gcs_helper", ":ram_file_block_cache", "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index e01af918100..5285989d6f1 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -23,6 +23,7 @@ limitations under the License. #include "google/cloud/storage/client.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for GCS environments. @@ -134,6 +135,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, } // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here. TF_SetStatus(status, TF_OK, ""); + TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset, + read); stream.read(buffer, read); read = stream.gcount(); if (read < buffer_size) { @@ -146,6 +149,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, path, " @ ", offset) .c_str()); } + TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(), + offset); } } return read; @@ -284,6 +289,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object, TF_SetStatusFromGCSStatus(metadata.status(), status); return; } + TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(), + temporary_object.c_str(), bucket.c_str(), object.c_str()); const std::vector source_objects = { {object, {}, {}}, {temporary_object, {}, {}}}; metadata = gcs_client->ComposeObject(bucket, source_objects, object); @@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n, "The internal temporary file is not writable."); return; } + TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(), + gcs_file->object.c_str(), n); gcs_file->sync_need = true; gcs_file->outfile.write(buffer, n); if (!gcs_file->outfile) @@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) { void Flush(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); if (gcs_file->sync_need) { + TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (!gcs_file->outfile) { TF_SetStatus(status, TF_INTERNAL, "Could not append to the internal temporary file."); @@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset, &gcs_file->outfile, gcs_file->gcs_client, status); + TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (TF_GetCode(status) != TF_OK) return; gcs_file->sync_need = false; } else { @@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) { } void Sync(const TF_WritableFile* file, TF_Status* status) { + auto gcs_file = static_cast(file->plugin_file); + TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); Flush(file, status); } void Close(const TF_WritableFile* file, TF_Status* status) { auto gcs_file = static_cast(file->plugin_file); + TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(), + gcs_file->object.c_str()); if (gcs_file->sync_need) { Flush(file, status); } @@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client) if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) { max_staleness = value; } + TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u", + max_bytes, block_size, max_staleness); file_block_cache = std::make_unique( block_size, max_bytes, max_staleness, @@ -511,6 +531,10 @@ static void UncachedStatForObject(const std::string& bucket, stat->base.mtime_nsec = metadata->time_storage_class_updated().time_since_epoch().count(); stat->base.is_directory = object.back() == '/'; + TF_VLog(1, + "Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;", + bucket.c_str(), object.c_str(), stat->base.length, + stat->generation_number, stat->base.mtime_nsec); return TF_SetStatus(status, TF_OK, ""); } @@ -545,9 +569,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, if (TF_GetCode(status) != TF_OK) return -1; if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature( path, stat.generation_number)) { - std::cout - << "File signature has been changed. Refreshing the cache. Path: " - << path; + TF_VLog( + 1, + "File signature has been changed. Refreshing the cache. Path: %s", + path.c_str()); } read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status); } else { @@ -579,6 +604,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, (gcs_file->compose ? 0 : -1)}); // We are responsible for freeing the pointer returned by TF_GetTempFileName free(temp_file_name); + TF_VLog(3, "GcsWritableFile: %s", path); TF_SetStatus(status, TF_OK, ""); } @@ -624,7 +650,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, return; } } - + TF_VLog(3, "GcsWritableFile: %s with existing file %s", path, + temp_file_name.c_str()); TF_SetStatus(status, TF_OK, ""); } @@ -812,6 +839,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { std::string dir = path; MaybeAppendSlash(&dir); + TF_VLog(3, + "CreateDir: creating directory with path: %s and " + "path_with_slash: %s", + path, dir.c_str()); std::string bucket, object; ParseGCSPath(dir, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -826,8 +857,11 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, } PathExists(filesystem, dir.c_str(), status); - if (TF_GetCode(status) == TF_OK) + if (TF_GetCode(status) == TF_OK) { + // Use the original name for a correct error here. + TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path); return TF_SetStatus(status, TF_ALREADY_EXISTS, path); + } auto metadata = gcs_file->gcs_client.InsertObject( bucket, object, "", @@ -933,6 +967,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path, static void RenameObject(const TF_Filesystem* filesystem, const std::string& src, const std::string& dst, TF_Status* status) { + TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str()); std::string bucket_src, object_src; ParseGCSPath(src, false, &bucket_src, &object_src, status); if (TF_GetCode(status) != TF_OK) return; @@ -946,6 +981,7 @@ static void RenameObject(const TF_Filesystem* filesystem, bucket_src, object_src, bucket_dst, object_dst); TF_SetStatusFromGCSStatus(metadata.status(), status); if (TF_GetCode(status) != TF_OK) return; + TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str()); ClearFileCaches(gcs_file, dst); DeleteFile(filesystem, src.c_str(), status); diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 934fa6d2bda..48a20ef7768 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -43,8 +43,8 @@ class ConcreteFunction { virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) = 0; + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 2b883618c87..25cac39daa0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -28,6 +28,26 @@ cc_library( ], ) +cc_library( + name = "flat_tensor_function", + srcs = [ + "flat_tensor_function.cc", + ], + hdrs = [ + "flat_tensor_function.h", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "variable", srcs = [ @@ -68,7 +88,7 @@ cc_library( "tf_concrete_function.h", ], deps = [ - ":tensorhandle_convertible", + ":flat_tensor_function", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", @@ -81,3 +101,26 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tf_signature_def_function", + srcs = [ + "tf_signature_def_function.cc", + ], + hdrs = [ + "tf_signature_def_function.h", + ], + deps = [ + ":flat_tensor_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc new file mode 100644 index 00000000000..ad9f896f43d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +FlatTensorFunction::FlatTensorFunction( + const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx) + : name_(name), captures_(std::move(captures)), ctx_(ctx) {} + +FlatTensorFunction::~FlatTensorFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " + << status.error_message(); + } +} + +Status FlatTensorFunction::Create( + const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + out->reset(new FlatTensorFunction(function_def->signature().name(), + std::move(captures), ctx)); + return Status(); +} + +Status FlatTensorFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + + // Adding the user-provided inputs to the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); + + absl::Span captures( + reinterpret_cast(captures_.data()), + captures_.size()); + + // Adding the captures of the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h new file mode 100644 index 00000000000..e6bcdec7e3a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FlatTensorFunction models a TF2 eager runtime view of a callable function, +// taking + returning flat lists of tensors, including any captures. +// Effectively, it is a thin wrapper around a FunctionDef owned by the +// EagerContext, and any TensorHandle captures associated with the function. The +// MakeCallOp method handles the logic of marshaling captures after the user +// provided inputs automatically. +// Note(bmzhao): This class is mainly intended to house low-level reusable +// function logic between SignatureDefFunction and ConcreteFunction, which +// present higher level interfaces. This type does *not* hold any "function +// metadata". +class FlatTensorFunction { + public: + // Factory for creating a FlatTensorFunction. + // + // Params: + // function_def - The function_def associated with the created + // FlatTensorFunction. FlatTensorFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // FlatTensorFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output FlatTensorFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const; + + ~FlatTensorFunction(); + + private: + FlatTensorFunction(const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx); + + FlatTensorFunction(const FlatTensorFunction&) = delete; + FlatTensorFunction& operator=(const FlatTensorFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc index f734f9eca66..d9773a4520f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/errors.h" @@ -33,32 +33,20 @@ limitations under the License. namespace tensorflow { -TFConcreteFunction::TFConcreteFunction( - const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFConcreteFunction::~TFConcreteFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " - << status.error_message(); - } -} +TFConcreteFunction::TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} Status TFConcreteFunction::Create( const FunctionDef* function_def, std::vector captures, FunctionMetadata metadata, ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFConcreteFunction(function_def->signature().name(), - std::move(captures), std::move(metadata), - ctx)); + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFConcreteFunction(std::move(func), std::move(metadata))); return Status(); } @@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { return metadata_; } -Status TFConcreteFunction::GetCallOp( - absl::Span inputs, ImmediateOpPtr* out) { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - - // Adding the user-provided inputs to the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); - - absl::Span captures( - reinterpret_cast(captures_.data()), - captures_.size()); - - // Adding the captures of the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); - return Status(); +Status TFConcreteFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h index d38f3546f91..edc26f4d5aa 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction { std::unique_ptr* out); // This method returns the "Call" Op used to execute the function. - Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) override; + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; const FunctionMetadata& GetFunctionMetadata() const override; - ~TFConcreteFunction() override; + ~TFConcreteFunction() override = default; private: - TFConcreteFunction(const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx); + TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata); TFConcreteFunction(const TFConcreteFunction&) = delete; TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; - // Name of the FunctionDef corresponding to this TFConcreteFunction - std::string name_; - std::vector captures_; + std::unique_ptr func_; FunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc new file mode 100644 index 00000000000..ab1745dcd47 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +TFSignatureDefFunction::TFSignatureDefFunction( + std::unique_ptr func, + SignatureDefFunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} + +Status TFSignatureDefFunction::Create( + const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out) { + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata))); + return Status(); +} + +const SignatureDefFunctionMetadata& +TFSignatureDefFunction::GetFunctionMetadata() const { + return metadata_; +} + +Status TFSignatureDefFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h new file mode 100644 index 00000000000..7b564185b8b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// This is the TF eager runtime implementation of SignatureDefFunction (separate +// from the TFRT implementation). The user-facing API of SignatureDefFunctions +// and their semantic differences from ConcreteFunction are described here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59 +// Additional implementation notes are available here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48 +class TFSignatureDefFunction : public SignatureDefFunction { + public: + // Factory function for creating a TFSignatureDefFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFSignatureDefFunction. TFSignatureDefFunction will + // register this function_def with `ctx` on creation, and + // de-register it on destruction. function_def must be + // non-null, but otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - FunctionMetadata associated with this TFSignatureDefFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFSignatureDefFunction. + // out - The output TFSignatureDefFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; + + const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; + + ~TFSignatureDefFunction() override = default; + + private: + TFSignatureDefFunction(std::unique_ptr func, + SignatureDefFunctionMetadata metadata); + + TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; + TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; + + std::unique_ptr func_; + SignatureDefFunctionMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 65c6eca5623..2beed8f4119 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, - TFE_TensorHandle** inputs, int num_inputs, - TF_Status* status) { +TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func, + TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status) { tensorflow::ImmediateOpPtr call_op; absl::Span input_span( reinterpret_cast( tensorflow::unwrap(inputs)), static_cast(num_inputs)); - status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index e58b232f9c9..df998fcf6cd 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { compute_fn_inputs.push_back(input_a); compute_fn_inputs.push_back(input_b); - TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( + TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp( compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 0fd0f70cf16..ff8a245961a 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( // high-level API here. A strawman for what this interface could look like: // TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* // inputs, int num_inputs, TF_Status* status); -TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp( TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status); diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 93b82b2396f..6bb2b347a30 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -132,6 +132,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "summary_op_benchmark_test", + size = "small", + srcs = ["summary_op_benchmark_test.cc"], + deps = [ + ":summary_op", + "//tensorflow/c:kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "tensor_shape_utils", srcs = ["tensor_shape_utils.cc"], diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc index 5de52703f5d..143a2675a05 100644 --- a/tensorflow/c/kernels/histogram_summary_op.cc +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -93,11 +93,13 @@ void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { std::ostringstream err; err << "Nan in summary histogram for: " << k->op_node_name; TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); return; } else if (Eigen::numext::isinf(double_val)) { std::ostringstream err; err << "Infinity in Histogram for: " << k->op_node_name; TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, status.get()); return; } histo.Add(double_val); diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc new file mode 100644 index 00000000000..887a86066d3 --- /dev/null +++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor tags(DT_STRING, shape); + Tensor values(DT_FLOAT, shape); + for (int i = 0; i < tags.NumElements(); ++i) { + tags.flat()(i) = tag; + values.flat()(i) = value; + } + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary") + .Input(test::graph::Constant(g, tags)) + .Input(test::graph::Constant(g, values)) + .Attr("T", DT_FLOAT) + .Finalize(g, &ret)); + return g; +} + +// Macro used to parse initializer list for tensorshape +#define DIMARGS(...) \ + { __VA_ARGS__ } +// // Random parameters for testing +constexpr char longTagParam[] = "LONGTAG____________________________"; +constexpr float largeValueParam = 2352352.2623433; + +#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ + void BM_ScalarSummary##name##device(int iters) { \ + testing::StopTiming(); \ + TensorShape tensorshape(DIMARGS dims); \ + auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ + BENCHMARK(BM_ScalarSummary##name##device); + +BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2); +// Benchmark for large shapes +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2); +// Benchmark for large tag tstring +BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2); +// Benchmark for large values +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 35c6a8b0357..5992b45e209 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -329,6 +329,7 @@ cc_library( srcs = ["xla_compilation_cache.cc"], hdrs = ["xla_compilation_cache.h"], deps = [ + ":flags", ":xla_activity_listener", ":xla_activity_proto_cc", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", @@ -361,8 +362,11 @@ tf_cc_test( "xla_compilation_cache_test.cc", ], deps = [ + ":flags", ":xla_compilation_cache", + ":xla_cpu_jit", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -918,6 +922,7 @@ tf_cc_test( ":xla_cpu_jit", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/compiler/tf2xla:test_util", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 6d4bc51f1b2..5575320c1be 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -518,10 +518,15 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( } } +// Returns `true` iff node has a given `attr` set to `true`. Returns `false` +// both for the missing attr, and the attr set to `false`. +static bool HasBoolAttr(const NodeDef& node, const char* attr) { + const auto& it = node.attr().find(attr); + return it != node.attr().end() && it->second.b(); +} + bool CanCreateXlaKernel(const NodeDef& node_def) { - // If kXlaMustCompileAttr is set on the node_def, use its value. - const auto& it = node_def.attr().find(kXlaMustCompileAttr); - return it != node_def.attr().end() && it->second.b(); + return HasBoolAttr(node_def, kXlaMustCompileAttr); } Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, @@ -564,4 +569,58 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, return Status::OK(); } +static auto const ops_triggering_xla_compilation = + new absl::flat_hash_set{"XlaBroadcastHelper", + "XlaConv", + "XlaDequantize", + "XlaDot", + "XlaDynamicSlice", + "XlaDynamicUpdateSlice", + "XlaEinsum", + "XlaGather", + "XlaIf", + "XlaKeyValueSort", + "XlaPad", + "XlaRecv", + "XlaReduce", + "XlaReduceWindow", + "XlaReplicaId", + "XlaScatter", + "XlaSelectAndScatter", + "XlaSelfAdjointEig", + "XlaSend", + "XlaSharding", + "XlaSort", + "XlaSpmdFullToShardShape", + "XlaSpmdShardToFullShape", + "XlaSvd", + "XlaWhile"}; + +static bool NodeCanTriggerXlaCompilation(const NodeDef& node) { + return node.attr().find(kXlaClusterIdAttr) != node.attr().end() || + HasBoolAttr(node, kXlaMustCompileAttr) || + HasBoolAttr(node, kXlaCompileAttr) || + HasBoolAttr(node, kXlaScopeAttr) || + HasBoolAttr(node, kXlaInternalScopeAttr) || + ops_triggering_xla_compilation->count(node.op()); +} + +bool CanTriggerXlaCompilation(const GraphDef& graph) { + for (const FunctionDef& function : graph.library().function()) { + for (const NodeDef& node : function.node_def()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + } + + for (const NodeDef& node : graph.node()) { + if (NodeCanTriggerXlaCompilation(node)) { + return true; + } + } + + return false; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 3b20784cc29..384367c33c6 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -126,9 +126,10 @@ class RecursiveCompilabilityChecker { bool allow_inaccurate_ops = false; }; - RecursiveCompilabilityChecker(const OperationFilter* op_filter, - const DeviceType* jit_device_type) - : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {} + RecursiveCompilabilityChecker(OperationFilter op_filter, + DeviceType jit_device_type) + : op_filter_(std::move(op_filter)), + jit_device_type_(std::move(jit_device_type)) {} using UncompilableNodesMap = std::map(&op_filter_, - &device_type_); + checker_ = absl::make_unique(op_filter_, + device_type_); } FunctionLibraryRuntime* GetFunctionLibraryRuntime() { @@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { "unsupported op")); } +TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = identity_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("IdentityFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_FALSE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + FunctionDef sort_func = FunctionDefHelper::Create( + "SortFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + *library.add_function() = sort_func; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("SortFunc"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + +TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + + AttrValue true_attribute; + true_attribute.set_b(true); + + FunctionDef identity_func = FunctionDefHelper::Create( + "IdentityFunc", + /*in_def=*/{"x:float"}, + /*out_def=*/{"res:float"}, + /*attr_def=*/{}, + /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}}, + /*ret_def*/ {{"res", "t0:output"}}); + + (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute; + + FunctionDef call_identity = FunctionDefHelper::Create( + "CallIdentity", + /*in_def=*/{"x:float"}, + /*out_def=*/{"z:float"}, /*attr_def=*/{}, + /*node_def=*/ + {{{"func_call"}, + "PartitionedCall", + {"x"}, + {{"Tin", DataTypeSlice({DT_FLOAT})}, + {"Tout", DataTypeSlice({DT_FLOAT})}, + {"f", + FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})}, + {kXlaMustCompileAttr, true}}}}, + /*ret_def=*/{{"z", "func_call:output:0"}}); + + *library.add_function() = identity_func; + *library.add_function() = call_identity; + + Output in = ops::Placeholder(root, DT_FLOAT); + NameAttrList b_name_attr; + b_name_attr.set_name("CallIdentity"); + ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, + b_name_attr); + + GraphDef graph_def; + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library)); + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index 4bea71e8fc1..84e1e36bcf6 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope"; // only when auto_jit is ON. const char* const kXlaInternalScopeAttr = "_XlaInternalScope"; +const char* const kXlaClusterIdAttr = "_xla_compile_id"; + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index 9eb4c2ca2e8..fa983db8df8 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaScopeAttr; // "_XlaScope" extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" +// The id of the compiled cluster. +extern const char* const kXlaClusterIdAttr; // "_xla_compile_id" + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_DEFS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index ed25baa62ff..4a5c79c02d9 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -34,9 +35,6 @@ limitations under the License. namespace tensorflow { -const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = - "_xla_compile_id"; - namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; @@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { string name; // Only consider nodes being compiled. - if (!GetNodeAttr(n->attrs(), - EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) - .ok()) - continue; + if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue; // Early return for any node with a device that is not a CPU or GPU. DeviceNameUtils::ParsedName parsed; if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { @@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, retvals[i]->AddAttr("index", i); } - AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), - call_def); + AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def); AddNodeAttr("_variable_start_index", variable_start_index, call_def); // Uniquify the function name. @@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // O(n) pass over the edges. for (const Edge* e : (*graph)->edges()) { if (!e->IsControlEdge() && - e->src()->attrs().Find(kXlaClusterAttr) != nullptr && - e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( "Undeclared output of XLA computation. Some common causes of this " @@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto output = absl::make_unique((*graph)->op_registry()); TF_RETURN_WITH_CONTEXT_IF_ERROR( - EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph, - /*reuse_existing_functions=*/true, - &output, flib_def), + EncapsulateSubgraphsInFunctions( + kXlaClusterIdAttr, **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), "EncapsulateXlaComputationsPass failed"); graph->swap(output); return Status::OK(); @@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // while iterating. std::vector launch_nodes; for (Node* n : graph->nodes()) { - const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr); + const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr); if (!name.empty()) { launch_nodes.push_back(n); } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index 3057e4c7469..9931b23fa41 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -34,8 +34,6 @@ namespace tensorflow { // XlaLaunch operators. class EncapsulateXlaComputationsPass : public GraphOptimizationPass { public: - static const char* const kXlaClusterAttr; // _xla_compile_id - Status Run(const GraphOptimizationPassOptions& options) override; // The following methods are public only for unit tests. diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index cc177036591..61c9a3ff9c0 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" @@ -46,19 +47,18 @@ static std::unique_ptr MakeOuterGraph( auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); NodeDef def; - TF_CHECK_OK( - NodeDefBuilder("launch0", function, &flib_def) - .Input(a.node()->name(), 0, DT_INT32) - .Input(b.node()->name(), 0, DT_FLOAT) - .Input(c.node()->name(), 0, DT_INT32) - .Input(d.node()->name(), 0, DT_FLOAT) - .Input(u.node()->name(), 0, DT_RESOURCE) - .Input(v.node()->name(), 0, DT_RESOURCE) - .Input(w.node()->name(), 0, DT_RESOURCE) - .Device("/gpu:0") - .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") - .Attr("_variable_start_index", 4) - .Finalize(&def)); + TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") + .Attr(kXlaClusterIdAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); Status status; Node* launch = scope.graph()->AddNode(def, &status); @@ -107,7 +107,7 @@ static std::unique_ptr MakeBodyGraph() { auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; @@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { : ops::Add(scope.WithOpName("E"), a1, a0); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, - "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); }; add_attrs(e.node()); @@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->AddAttr(kXlaClusterIdAttr, "launch0"); node->set_requested_device("/gpu:0"); }; diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index a4a750bae0d..ee7daf092da 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -268,4 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector* flag_list) { AppendMarkForCompilationPassFlagsInternal(flag_list); } +static std::atomic xla_compilation_disabled(false); + +void DisableXlaCompilation() { xla_compilation_disabled = true; } + +bool FailOnXlaCompilation() { return xla_compilation_disabled; } + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 6c54fc8825e..5612b3b5864 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -162,6 +162,13 @@ MlirCommonFlags* GetMlirCommonFlags(); void AppendMarkForCompilationPassFlags( std::vector* flag_list); +// Disables XLA compilation, forces it to return an error message instead. Can +// be used by a server to ensure that JIT compilation is opt-in. +void DisableXlaCompilation(); + +// Returns `false` unless `DisableXlaCompilation` was called. +bool FailOnXlaCompilation(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index de462928c46..79f1e47d98b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, constants_(constants), resources_(resources), function_(function), - platform_info_(XlaPlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), has_ref_vars_(has_ref_vars) {} static Status CompileToLocalExecutable( @@ -180,7 +180,7 @@ static Status CompileToLocalExecutable( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", &cache, [&](XlaCompilationCache** cache) { - return BuildXlaCompilationCache(ctx, platform_info, cache); + return BuildXlaCompilationCache(ctx->device(), platform_info, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but @@ -191,7 +191,9 @@ static Status CompileToLocalExecutable( absl::optional tf_allocator_adapter; XlaCompiler::Options options = GenerateCompilerOptions( - *cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter); + *cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info, has_ref_vars, &tf_allocator_adapter); std::map constant_args; for (int i : constants) { @@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); int device_ordinal = stream ? stream->parent()->device_ordinal() : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( @@ -373,7 +377,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)), - platform_info_(XlaPlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), must_compile_(MustCompileAttr(ctx)), has_ref_vars_(HasRefVars(ctx)) {} @@ -461,7 +465,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); @@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Consume(key); absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; int device_ordinal = stream ? stream->parent()->device_ordinal() diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 19eb61b6f72..03ac7b0a59a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1196,12 +1196,9 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - DeviceType jit_device_type(registration->compilation_device_name); - - RecursiveCompilabilityChecker::OperationFilter op_filter = - CreateOperationFilter(*registration); - - if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type} + if (!RecursiveCompilabilityChecker{ + CreateOperationFilter(*registration), + DeviceType{registration->compilation_device_name}} .IsCompilableNode(*node, lib_runtime)) { continue; } @@ -1718,7 +1715,6 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, const XlaOpRegistry::DeviceRegistration* registration; CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); - DeviceType jit_device_type(registration->compilation_device_name); // We can always *compile* resource operations, stateful RNGs and dummy ops, // even if we are sometimes unable to auto-cluster them. @@ -1733,7 +1729,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, op_filter.allow_slow_ops = true; op_filter.allow_inaccurate_ops = true; - RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type}; + RecursiveCompilabilityChecker checker{ + op_filter, DeviceType{registration->compilation_device_name}}; if (!uncompilable_node_info) { // We do not need uncompilable node info. Just return the result. return checker.IsCompilableCall(ndef, flr); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 971a5383f6b..b5bb2fab0ed 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" @@ -323,6 +324,10 @@ Status XlaCompilationCache::CompileImpl( absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { + if (FailOnXlaCompilation()) { + return errors::Internal("XLA compilation disabled"); + } + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc index 7227615d2bb..5578925b790 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) { } } +TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { + NameAttrList fn; + fn.set_name("afunction"); + + DisableXlaCompilation(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT); + + const XlaCompiler::CompilationResult* compilation_result; + xla::LocalExecutable* executable; + + auto cache = new XlaCompilationCache(client, device_type); + core::ScopedUnref cache_ref(cache); + + Status status = cache->Compile(XlaCompiler::Options{}, fn, {}, + XlaCompiler::CompileOptions{}, + XlaCompilationCache::CompileMode::kStrict, + &compilation_result, &executable); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "XLA compilation disabled")); +} + static void BM_BuildSignature(int iters, int n_args) { NameAttrList fn; fn.set_name("afunction"); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index da251c2c8f3..ba20b532a11 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, xla::LocalClient* client = static_cast(cache->client()); absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::DeviceMemoryAllocator* allocator = GetAllocator( + &tf_allocator_adapter, ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_); XlaComputationLaunchContext launch_context( client, allocator, client->default_device_ordinal(), /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, @@ -157,13 +159,16 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", cache, [&](XlaCompilationCache** write_into_cache) { - return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache); + return BuildXlaCompilationCache(ctx->device(), platform_info_, + write_into_cache); })); absl::optional tf_allocator_adapter; - XlaCompiler::Options options = - GenerateCompilerOptions(**cache, ctx, platform_info_, - /*has_ref_vars=*/true, &tf_allocator_adapter); + XlaCompiler::Options options = GenerateCompilerOptions( + **cache, *ctx->function_library(), ctx->device(), + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, + platform_info_, + /*has_ref_vars=*/true, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 095d3427d41..bb8ab889ce9 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -37,7 +37,8 @@ namespace tensorflow { class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} void Compute(OpKernelContext* ctx) override; private: diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 446cd8944de..dd1ddb616f5 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -51,7 +51,7 @@ Status XlaCpuDeviceFactory::CreateDevices( std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } bool compile_on_demand = flags->tf_xla_compile_on_demand; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index f7e7ee9cf95..6d6086ce0fa 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice { static Status GetMetadata(OpKernelConstruction* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by + // `device`. + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + struct Options { // The StreamExecutor platform. Not owned. Must be non-null. se::Platform* platform = nullptr; @@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice { xla::StatusOr> GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetMetadataFromDevice(DeviceBase* device, - const XlaDevice::Metadata** metadata); Status MakeTensorFromProto(XlaDeviceContext* device_context, const TensorProto& tensor_proto, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 16f496d51a3..99ba5658819 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -66,7 +66,7 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 19e2b5a2bb5..ed6e399aed4 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -44,12 +44,6 @@ namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; -const char kPossibleNonVariableResourceHintMessage[] = - "If the error is similar to `Trying to access resource using the wrong " - "type`, this is likely because XLA only accepts Resource Variables as " - "inputs by snapshotting their values. Other TensorFlow resource types like " - "TensorList/TensorArray/Stack are not supported. Try removing non-variable " - "resource inputs to XLA."; } // anonymous namespace VariableInfo::VariableInfo(int index, absl::string_view name, Var* var) diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index a5e12b37563..b38bf9282b1 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -Status BuildXlaCompilationCache(OpKernelContext* ctx, +Status BuildXlaCompilationCache(DeviceBase* device, const XlaPlatformInfo& platform_info, XlaCompilationCache** cache) { if (platform_info.xla_device_metadata()) { @@ -59,7 +59,7 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx, xla::LocalClientOptions client_options; client_options.set_platform(platform.ValueOrDie()); client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + device->tensorflow_cpu_worker_threads()->num_threads); auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); @@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx, return Status::OK(); } -XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) { - DeviceType device_type = ctx->device_type(); +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { + auto device = static_cast(device_base); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; se::DeviceMemoryAllocator* custom_allocator = nullptr; - if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + if (device->device_type() == DEVICE_CPU) { platform_id = se::host::kHostPlatformId; - } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { - platform_id = ctx->device() - ->tensorflow_gpu_device_info() + } else if (device->device_type() == DEVICE_GPU) { + platform_id = device->tensorflow_gpu_device_info() ->stream->parent() ->platform() ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata) + .ok()) { // If we are on an XlaDevice, use the underlying XLA platform's allocator // directly. We could use the StreamExecutor's allocator which may // theoretically be more correct, but XLA returns a nice OOM message in a @@ -104,47 +104,46 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) { xla_device_metadata->client()->backend().memory_allocator(); } - return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - custom_allocator); + return XlaPlatformInfo(DeviceType(device->device_type()), platform_id, + xla_device_metadata, custom_allocator); } se::DeviceMemoryAllocator* GetAllocator( absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info) { if (platform_info.custom_allocator()) { return platform_info.custom_allocator(); } - if (!ctx->op_device_context()) { + if (!stream) { // Stream is not set for the host platform. se::Platform* platform = se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) .ValueOrDie(); - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); + tf_allocator_adapter->emplace(device->GetAllocator({}), platform); return &tf_allocator_adapter->value(); } - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), - ctx->op_device_context()->stream()); + tf_allocator_adapter->emplace(device->GetAllocator({}), stream); return &tf_allocator_adapter->value(); } XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional* tf_allocator_adapter) { - CHECK(ctx->function_library()); XlaCompiler::Options options; options.client = static_cast(cache.client()); - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); + if (stream != nullptr) { + options.device_ordinal = stream->parent()->device_ordinal(); } options.device_type = cache.device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); + options.flib_def = function_library.GetFunctionLibraryDefinition(); + options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = (platform_info.platform_id() == se::host::kHostPlatformId); options.device_allocator = - GetAllocator(tf_allocator_adapter, ctx, platform_info); + GetAllocator(tf_allocator_adapter, device, stream, platform_info); if (platform_info.xla_device_metadata()) { options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index d58b32a996f..bfb438cc398 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -80,27 +80,31 @@ class XlaPlatformInfo { }; // Returns created XLA compilation cache. -Status BuildXlaCompilationCache(OpKernelContext* ctx, +Status BuildXlaCompilationCache(DeviceBase* dev, const XlaPlatformInfo& platform_info, XlaCompilationCache** cache); // Returns information about the platform from kernel context. -XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx); +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); // Returns allocator from platform info if non-null, or populate and return a // pointer to the allocator adapter with allocator from context. // // This is necessary because for XLA devices the underlying TF allocator returns // dummy tensors. +// +// `stream` parameter is nullable when running on host. se::DeviceMemoryAllocator* GetAllocator( absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info); + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info); // Returns created options for the XLA compiler, and writes the used allocator // into `tf_allocator_adapter`. XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, bool has_ref_vars, + const XlaCompilationCache& cache, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, absl::optional* tf_allocator_adapter); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 731394a89da..5b79f7835c0 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -24,11 +24,40 @@ filegroup( srcs = glob(["**/*.td"]), ) +cc_library( + name = "string_container_utils", + hdrs = ["utils/string_container_utils.h"], + deps = [ + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "array_container_utils", + hdrs = ["utils/array_container_utils.h"], + deps = [ + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "name_utils", + srcs = ["utils/name_utils.cc"], + hdrs = ["utils/name_utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "op_or_arg_name_mapper", srcs = ["op_or_arg_name_mapper.cc"], hdrs = ["op_or_arg_name_mapper.h"], deps = [ + ":name_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index a6ac886e972..4014b885290 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -341,6 +341,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -348,6 +349,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "mhlo_control_flow_to_scf", + srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], + deps = [ + ":hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "map_lmhlo_to_scalar_op", hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"], @@ -800,6 +817,7 @@ cc_library( ":lhlo_legalize_to_affine", ":lhlo_legalize_to_gpu", ":lhlo_legalize_to_parallel_loops", + ":mhlo_control_flow_to_scf", ":mhlo_fusion", ":mhlo_to_mhlo_lowering_patterns", ":sink_constants_to_control_flow", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index fa3bde24df1..aa0f4c317d4 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { let constructor = "createLegalizeControlFlowPass()"; } +def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> { + let summary = "Legalize from MHLO control flow to SCF control flow."; + let constructor = "createControlFlowToScfPass()"; +} + def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { let summary = "Legalizes gathers to a torch index select."; let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index efa116f3f0d..541d8e46ec5 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -35,6 +35,9 @@ namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); +/// Lowers MHLO control flow ops to the SCF dialect. +std::unique_ptr> createControlFlowToScfPass(); + /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt index bb9f98d32d3..945fa0ea991 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion add_mlir_library(MhloToStandard legalize_control_flow.cc legalize_to_standard.cc + mhlo_control_flow_to_scf.cc DEPENDS MLIRhlo_opsIncGen diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc new file mode 100644 index 00000000000..aba7b078413 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -0,0 +1,199 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#define DEBUG_TYPE "mhlo-control-flow-to-scf" + +namespace mlir { +namespace mhlo { + +namespace { + +/// Convert MHLO While to SCF. +void MatchAndRewrite(WhileOp whileOp); + +/// Pass that converts MHLO control flow to SCF. +class ControlFlowToScfPass + : public mlir::PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { + getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); }); + } +}; + +// TODO(jpienaar): Look into reformulating as a pattern. +void MatchAndRewrite(WhileOp whileOp) { + // Handle pattern: + // x = start + // step = ... + // limit = ... + // while (x < limit) { ... x += step; } + + // Only handling multi value while loops at the moment. + auto tupleOp = whileOp.getOperand().getDefiningOp(); + if (!tupleOp) return; + auto bodyReturn = whileOp.body() + .front() + .getTerminator() + ->getOperand(0) + .getDefiningOp(); + // Note: due to the shape restrictions on While, if the operand to While is a + // tuple, then so is the return type of the body. But the verifier isn't + // checking that at the moment, so just bail out here if this doesn't hold. + if (!bodyReturn) return; + + Value result = whileOp.cond().front().getTerminator()->getOperand(0); + // TODO(jpienaar): Expand to handle more than simple case with LT compare and + // constant step. + auto cmp = result.getDefiningOp(); + if (!cmp || cmp.comparison_direction() != "LT") return; + + const int kConstant = -1; + auto getValueAndIndex = [&](Value val) -> std::pair { + if (matchPattern(val, m_Constant())) return {val, kConstant}; + // If it is defined by a tuple, then the tuple has to have been fed in and + // the external value is captured. + if (auto gte = val.getDefiningOp()) { + if (!gte.getOperand().isa()) return {nullptr, 0}; + int index = gte.index().getSExtValue(); + return {tupleOp.getOperand(index), index}; + } + return {nullptr, 0}; + }; + + using ValueIndex = std::pair; + ValueIndex loopIndVar = getValueAndIndex(cmp.lhs()); + ValueIndex max = getValueAndIndex(cmp.rhs()); + if (!loopIndVar.first || !max.first) return; + auto add = + bodyReturn.getOperand(loopIndVar.second).getDefiningOp(); + if (!add) return; + ValueIndex step = getValueAndIndex(add.rhs()); + if (step.second != kConstant || !step.first) return; + + // Only handle case where tuple isn't propagated as is for now. + // TODO(jpienaar): Remove this when a tuple is also created inside the loop + // to propagate. + for (auto* use : whileOp.body().front().getArgument(0).getUsers()) + if (!isa(use)) return; + + LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = " + << max.second << " step = " << step.second << "\n"; + llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = " + << max.first << " step = " << step.first << "\n";); + OpBuilder b(whileOp); + // Inputs to new for loop. + llvm::SmallVector input; + input.reserve(tupleOp.getNumOperands()); + for (auto r : tupleOp.getOperands().take_front(loopIndVar.second)) + input.push_back(r); + for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1)) + input.push_back(r); + + auto tensorIndexType = RankedTensorType::get({}, b.getIndexType()); + auto getAsIndex = [&](Value val) { + auto loc = whileOp.getLoc(); + return b.create( + loc, b.create(loc, tensorIndexType, val), ValueRange()); + }; + + // SCF for uses index type, so converted these. + auto forloopIndVar = getAsIndex(loopIndVar.first); + auto forMax = getAsIndex(max.first); + auto forStep = getAsIndex(step.first); + auto forOp = b.create(whileOp.getLoc(), forloopIndVar, + forMax, forStep, input); + // Transfer the body without the block arguments. + forOp.getLoopBody().front().getOperations().splice( + forOp.getLoopBody().front().getOperations().end(), + whileOp.body().front().getOperations()); + + b.setInsertionPointToStart(&forOp.getLoopBody().front()); + auto loopIndVarElType = + loopIndVar.first.getType().cast().getElementType(); + Value indVar = b.create( + whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType), + b.create(whileOp.getLoc(), loopIndVarElType, + forOp.getInductionVar())); + // Update all block argument users to the SCF For args. + for (auto* use : + llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) { + // TODO(jpienaar): Expand here too when we allow using the tuple in the + // loop. + auto gte = cast(use); + // If the loop induction var, then refer to the loop induction variable as + // this operand is not updated. + if (gte.index() == loopIndVar.second) { + use->getResult(0).replaceAllUsesWith(indVar); + use->erase(); + continue; + } + int index = gte.index().getSExtValue(); + // If after the loop induction variable, then decrement as we don't include + // the loop induction variable in the for iter operands. + if (index > loopIndVar.second) --index; + use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]); + use->erase(); + } + + // Create new yield op without induction var update. + SmallVector newYieldOps; + newYieldOps.reserve(bodyReturn.getNumOperands() - 1); + for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second)) + newYieldOps.push_back(r); + for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1)) + newYieldOps.push_back(r); + // Delete return & tuple op. + forOp.getLoopBody().front().back().erase(); + forOp.getLoopBody().front().back().erase(); + b.setInsertionPointToEnd(&forOp.getLoopBody().front()); + b.create(whileOp.getLoc(), newYieldOps); + + // Recombine output tuple with max value of induction variable. + llvm::SmallVector loopOut; + loopOut.reserve(forOp.getNumResults() + 1); + for (auto r : forOp.getResults().take_front(loopIndVar.second)) + loopOut.push_back(r); + loopOut.push_back(max.first); + for (auto r : forOp.getResults().drop_front(loopIndVar.second)) + loopOut.push_back(r); + b.setInsertionPoint(whileOp); + auto newRes = b.create(whileOp.getLoc(), loopOut); + whileOp.replaceAllUsesWith(newRes.getOperation()); + whileOp.erase(); +} + +} // anonymous namespace + +std::unique_ptr> createControlFlowToScfPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir new file mode 100644 index 00000000000..9c887a73a0f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s + +func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xf32>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tuple, tensor, tensor>) { + %cst = constant dense<-1> : tensor + %cst_0 = constant dense<1> : tensor + %cst_1 = constant dense<0> : tensor + %cst_2 = constant dense<1000> : tensor + %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + %1 = "mhlo.while"(%0) ( { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, tensor>) -> tensor + %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }, { + ^bb0(%arg9: tuple, tensor, tensor>): // no predecessors + %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple, tensor, tensor>) -> tensor + %3 = mhlo.add %2, %cst_0 : tensor + %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple, tensor, tensor>) -> tensor + %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple, tensor, tensor>) -> tensor + %6 = "mhlo.tuple"(%3, %4, %5) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%6) : (tuple, tensor, tensor>) -> () + }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> + return %1 : tuple, tensor, tensor> +} + +// CHECK-LABEL: func @lt_loop( +// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor +// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor +// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor +// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor +// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor to tensor +// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor +// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor to tensor +// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor +// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor to tensor +// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor +// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]]) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index d02e4e705f4..ba98e760590 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor, tensor<2xi32>, tensor) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>) } -func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { +func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} : +(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul +// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_0:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} : +(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_a +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_2:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + +func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} : (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> return %0 : tensor<40x40xf32> -// CHECK-LABEL: matmul_transposed +// CHECK-LABEL: matmul_transposed_b // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> } +func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} : +(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> + return %0 : tensor<40x40xf32> +// CHECK-LABEL: matmul_transposed_ab +// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32> +// CHECK: %[[CST_1:.*]] = constant unit +// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> +} + func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index b31da15c35f..bc460dbce2a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -66,7 +66,6 @@ namespace TFL { // The actual LegalizeTF Pass. namespace { -using xla::Status; using xla::StatusOr; constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; @@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( return success(); } -// The following is effectively: -// def : Pat< -// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a, -// ConstBoolAttrTrue:$transpose_b), -// (TFL_FullyConnectedOp:$__0 $a, $b, -// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>; LogicalResult ConvertTFMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_matmul_op = cast(op); - if (tf_matmul_op.transpose_a()) return failure(); - if (!tf_matmul_op.transpose_b()) return failure(); + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto transpose = [&](Value input) -> std::pair { + RankedTensorType type = + input.getType().dyn_cast_or_null(); + if (!type || type.getRank() != 2) return {failure(), nullptr}; + + auto permute_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0}); + auto permute = rewriter.create( + op->getLoc(), permute_attr.getType(), permute_attr); + llvm::SmallVector new_shape{type.getShape()[1], + type.getShape()[0]}; + auto output = rewriter.create( + op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()), + input, permute); + return {success(), output}; + }; + + // TODO(jpienaar): Remove once handled via dailect conversion. + if (tf_matmul_op.transpose_a()) { + LogicalResult result = success(); + std::tie(result, lhs) = transpose(lhs); + if (failed(result)) return failure(); + } + if (!tf_matmul_op.transpose_b()) { + LogicalResult result = success(); + std::tie(result, rhs) = transpose(rhs); + if (failed(result)) return failure(); + } Type output_type = tf_matmul_op.getResult().getType(); - // TODO(jpienaar): Follow up post shuffle discussion. auto no_input = rewriter.create( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); auto fc_op = rewriter.create( - op->getLoc(), ArrayRef{output_type}, op->getOperand(0), - op->getOperand(1), no_input, rewriter.getStringAttr("NONE"), - rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false)); + op->getLoc(), ArrayRef{output_type}, lhs, rhs, no_input, + rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"), + rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {fc_op.getResult(0)}); return success(); } diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index bce0ed4a33d..6b605741355 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/utils/name_utils.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { return absl::string_view(ref.data(), ref.size()); @@ -103,62 +104,16 @@ int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) { bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; } -namespace { -// Derives name from location. -std::string GetNameFromLoc(mlir::Location loc) { - llvm::SmallVector loc_names; - llvm::SmallVector locs; - locs.push_back(loc); - bool names_is_nonempty = false; - - while (!locs.empty()) { - mlir::Location curr_loc = locs.pop_back_val(); - - if (auto name_loc = curr_loc.dyn_cast()) { - // Add name in NameLoc. For NameLoc we also account for names due to ops - // in functions where the op's name is first. - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } else if (auto call_loc = curr_loc.dyn_cast()) { - // Add name if CallSiteLoc's callee has a NameLoc (as should be the - // case if imported with DebugInfo). - if (auto name_loc = call_loc.getCallee().dyn_cast()) { - auto name = name_loc.getName().strref().split('@').first; - loc_names.push_back(name); - if (!name.empty()) names_is_nonempty = true; - continue; - } - } else if (auto fused_loc = curr_loc.dyn_cast()) { - // Push all locations in FusedLoc in reverse order, so locations are - // visited based on order in FusedLoc. - auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); - locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); - continue; - } - - // Location is not a supported, so an empty StringRef is added. - loc_names.push_back(llvm::StringRef()); - } - - if (names_is_nonempty) - return llvm::join(loc_names.begin(), loc_names.end(), ";"); - - return ""; -} -} // anonymous namespace - std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { if (auto* op = op_or_val.dyn_cast()) { - auto name_from_loc = GetNameFromLoc(op->getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(op->getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. return std::string(op->getName().getStringRef()); } auto val = op_or_val.dyn_cast(); - auto name_from_loc = GetNameFromLoc(val.getLoc()); + auto name_from_loc = mlir::GetNameFromLoc(val.getLoc()); if (!name_from_loc.empty()) return name_from_loc; // If the location is none of the expected types, then simply use name // generated using the op type. Follow TF convention and append the result diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 115a5780e08..1344ded4804 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -794,6 +794,7 @@ cc_library( "transforms/tpu_identity_pruning.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_outside_compilation_cluster.cc", + "transforms/tpu_resource_read_for_write.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_space_to_depth_pass.cc", @@ -960,6 +961,7 @@ cc_library( "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:loader_util", "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 243f4b5139f..5c6f39699bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -2,7 +2,6 @@ load( "//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library", - "tfe_xla_copts", ) package( @@ -20,7 +19,7 @@ tf_cuda_library( srcs = [ "c_api_unified_experimental_mlir.cc", ], - copts = tf_copts() + tfe_xla_copts(), + copts = tf_copts(), deps = [ "//tensorflow/c:c_api", "//tensorflow/c:tensor_interface", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index bd21ba015bf..6bfe4c302cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -452,7 +452,8 @@ Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) { return Unimplemented("SetAttrFloat has not been implemented yet."); } Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) { - return Unimplemented("SetAttrBool has not been implemented yet."); + attrs_[attr_name] = BoolAttr::get(value, context_); + return Status::OK(); } Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index ea9ae5d9477..eced738b0a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { // tf_executor.fetch //===----------------------------------------------------------------------===// -namespace { - -void Print(FetchOp fetch, OpAsmPrinter &p) { - p << fetch.getOperationName(); - if (fetch.getNumOperands() > 0) { - p << ' '; - p.printOperands(fetch.operand_begin(), fetch.operand_end()); - p << " : "; - interleaveComma(fetch.getOperandTypes(), p); - } - p.printOptionalAttrDict(fetch.getAttrs()); -} - -ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes) - - ); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.island //===----------------------------------------------------------------------===// @@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { // tf_executor.yield //===----------------------------------------------------------------------===// -namespace { - -void Print(YieldOp yield, OpAsmPrinter &p) { - p << yield.getOperationName(); - if (yield.getNumOperands() > 0) { - p << ' '; - p.printOperands(yield.operand_begin(), yield.operand_end()); - p << " : "; - interleaveComma(yield.getOperandTypes(), p); - } - p.printOptionalAttrDict(yield.getAttrs()); -} - -ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_info; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(op_info) || - (!op_info.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(op_info, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes)); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.Switch //===----------------------------------------------------------------------===// @@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) { return success(); } -void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " : " << next_iteration.getType(0); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, - OperationState &result) { - SmallVector types; - if (parser.parseColonTypeList(types)) return failure(); - - MLIRContext *context = parser.getBuilder().getContext(); - Type token_type = TokenType::get(context); - Type control_type = ControlType::get(context); - result.addTypes({types.front(), token_type, control_type}); - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) { return success(); } -void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " ["; - p.printOperand(next_iteration.getOperand(0)); - p << "] "; - p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); - p << " : " << next_iteration.getOperand(1).getType(); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSinkOp(OpAsmParser &parser, - OperationState &result) { - SmallVector op_infos; - llvm::SMLoc loc = parser.getCurrentLocation(); - - // First type is always the token consumed from the NextIteration.source - Type token_type = TokenType::get(parser.getBuilder().getContext()); - SmallVector types = {token_type}; - - if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) - return failure(); - - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size() - 2, control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { // tf_executor.ControlTrigger //===----------------------------------------------------------------------===// -namespace { - -void Print(ControlTriggerOp trigger, OpAsmPrinter &p) { - p << trigger.getOperationName() << ' '; - p.printOperands(trigger.getOperands()); - p.printOptionalAttrDict(trigger.getAttrs()); -} - -ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_infos; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(op_infos)) return failure(); - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size(), control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - // Single control as the only output - result.types.push_back(control_type); - return parser.parseOptionalAttrDict(result.attributes); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.LoopCond //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 3081018b8da..de2d2485628 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -47,10 +47,12 @@ def TfExecutor_Dialect : Dialect { } // Control type. -def TfeControlType : Type()">, "control">; +def TfeControlType : Type()">, "control">, + BuildableType<"$_builder.getType()">; // Token type. -def TfeTokenType : Type()">, "token">; +def TfeTokenType : Type()">, "token">, + BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands // and results. For example, MergeOp output type. @@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_IslandOp : TfExecutor_Op<"island", @@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", @@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", } }]; + let assemblyFormat = "`:` type($output) attr-dict"; + + let printer = ?; + let parser = ?; } @@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_ExitOp : TfExecutor_Op<"Exit", @@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", .Attr("T: type") For example: - %1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32> + %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"} Note: Additional result corresponds to the control output. }]; @@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = "$controlInputs attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 283e3326029..faf7d428aea 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -52,6 +52,12 @@ an output element, this operation computes \\(y = |x|\\). def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes acos of x element-wise."; + let description = [{ +Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`. + + Input range is `[-1, 1]` and the output has a range of `[0, pi]`. + }]; + let arguments = (ins TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); @@ -94,6 +100,10 @@ def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutA let description = [{ *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor. + +Both input and output have a range `(-inf, inf)`. }]; let arguments = (ins @@ -136,31 +146,6 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x + y element-wise."; - - let description = [{ -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> { let summary = "Adjust the contrast of one or more images."; @@ -1740,6 +1725,24 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let hasCanonicalizer = 1; } +def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> { + let summary = [{ +Sets up the centralized structures for a distributed TPU system. + }]; + + let arguments = (ins + StrAttr:$embedding_config, + StrAttr:$tpu_embedding_config, + DefaultValuedAttr:$is_global_init, + DefaultValuedAttr:$enable_whole_mesh_compilations, + DefaultValuedAttr:$compilation_failure_closes_chips + ); + + let results = (outs + TF_StrTensor:$topology + ); +} + def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns the complex conjugate of a complex number."; @@ -2786,27 +2789,6 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe let hasFolder = 1; } -def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns 0 if the denominator is zero."; - - let description = [{ -*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y - ); - - let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> { let summary = [{ Interleave the values from the `data` tensors into a single tensor. @@ -3853,6 +3835,95 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; } +def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { let summary = "Gather slices from `params` according to `indices`."; @@ -6213,14 +6284,14 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6324,27 +6395,6 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { }]; } -def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; - - let description = [{ -*NOTE*: `Maximum` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Computes the mean of elements across dimensions of a tensor."; @@ -6440,14 +6490,14 @@ retained with length 1. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7878,33 +7928,6 @@ tf.real(input) ==> [-2.25, 3.25] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, - WithBroadcastableBinOpBuilder { - let summary = "Returns x / y element-wise for real types."; - - let description = [{ -If `x` and `y` are reals, this will return the floating-point division. - -*NOTE*: `Div` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y - ); - - let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let hasCanonicalizer = 1; - - let hasFolder = 1; -} - def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the reciprocal of x element-wise."; @@ -9314,6 +9337,18 @@ Generate a sharded filename. The filename is printf formatted as ); } +def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> { + let summary = "Shuts down a running distributed TPU system."; + + let description = [{ +The op returns an error if no system is running. + }]; + + let arguments = (ins); + + let results = (outs); +} + def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes sigmoid of `x` element-wise."; @@ -9832,6 +9867,41 @@ backpropagation, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { + let summary = [{ +Multiply matrix "a" by matrix "b". + }]; + + let description = [{ +The inputs must be two-dimensional matrices and the inner dimension of "a" must +match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +`SparseTensor`s. This op is optimized for the case where at least one of "a" or +"b" is sparse, in the sense that they have a large proportion of zero values. +The breakeven for using this versus a dense matrix multiply on one platform was +30% zero values in the sparse matrix. + +The gradient computation of this operation will only take advantage of sparsity +in the input gradient when that gradient comes from a Relu. + }]; + + let arguments = (ins + TensorOf<[BF16, F32]>:$a, + TensorOf<[BF16, F32]>:$b, + + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$a_is_sparse, + DefaultValuedAttr:$b_is_sparse + ); + + let results = (outs + F32Tensor:$product + ); + + TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; +} + def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> { let summary = [{ Reshapes a SparseTensor to represent values in a new dense shape. @@ -11625,9 +11695,9 @@ array([[1, 2, 3, 1, 2, 3], TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - // TODO(parkers): Add folds for multiples = [1,...]. - // TODO(parkers): Add errors for negative multiples and multiples.size() != - // input.rank() + let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { @@ -12893,7 +12963,8 @@ create these operators. DefaultValuedAttr:$dilations, DefaultValuedAttr:$use_cudnn_on_gpu, DefaultValuedAttr:$fused_ops, - DefaultValuedAttr:$epsilon + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$leakyrelu_alpha ); let results = (outs diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 1755c975c23..4624680506a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -157,20 +157,10 @@ class TF_TensorFlowType : "TensorFlow " # description # " type">, BuildableType<"getType()">; -// Any tensor element type allowed in TensorFlow ops -def TF_ElementType : Type, - "tf.dtype">; - -// Any TensorFlow tensor type -def TF_Tensor : TensorOf<[TF_ElementType]>; - //===----------------------------------------------------------------------===// // Integer types +// TODO(mgester) shouldn't this be SignedIntOfWidths? def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; @@ -191,10 +181,11 @@ def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; // Any signed integer type +// TODO(mgester) shouldn't this be SignedIntOfWidths? def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; // Any integer type -def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>; +def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">; // Any integer tensor types def TF_IntTensor : TensorOf<[TF_Int]>; @@ -208,8 +199,8 @@ def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">; def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">; // Any quantized type -def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, - TF_Quint16]>; +def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, + TF_Quint16], "quantized">; //===----------------------------------------------------------------------===// // Floating-point types @@ -217,8 +208,10 @@ def TF_F32Or64 : FloatOfWidths<[32, 64]>; def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>; +def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">; + // Any floating-point tensor types -def TF_FpTensor : TensorOf<[AnyFloat]>; +def TF_FpTensor : TensorOf<[TF_Float]>; //===----------------------------------------------------------------------===// // Complex types @@ -231,10 +224,9 @@ def TF_Complex64Tensor : TensorOf<[TF_Complex64]>; def TF_Complex128 : Complex>; def TF_Complex128Tensor : TensorOf<[TF_Complex128]>; -def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128], - "64/128-bit complex type">; +def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">; -def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>; +def TF_ComplexTensor : TensorOf<[TF_Complex]>; //===----------------------------------------------------------------------===// // String/variant/resource types @@ -248,28 +240,113 @@ def TF_VariantTensor : TensorOf<[TF_Variant]>; def TF_Resource : TF_TensorFlowType<"Resource", "resource">; def TF_ResourceTensor : TensorOf<[TF_Resource]>; +//===----------------------------------------------------------------------===// +// Reference types + +// Float reference types +def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">; +def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">; +def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; +def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; + +// Any float reference type +def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref], + "floating-point reference">; + +// Complex reference types +def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; +def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">; + +// Any complex reference type +def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">; + +// Integer reference types +def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">; +def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">; +def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">; +def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">; + +def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">; +def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">; +def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">; +def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">; + +// Any signed integer reference type +def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref], + "signed integer reference">; + +// Any unsigned integer reference type +def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref, + TF_Uint64Ref], "unsigned integer reference">; + +// Any integer reference type +def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">; + +// Quantized reference types +def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">; +def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">; +def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">; +def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">; +def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">; + +// Any quantized reference type +def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref, + TF_Quint8Ref, TF_Quint16Ref], "quantized reference">; + +// Other reference types +def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">; +def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">; +def TF_StringRef : TF_TensorFlowType<"StringRef", "stringref">; +def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">; + +// Reference tensor types +def TF_FpRefTensor : TensorOf<[TF_FloatRef]>; +def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>; + //===----------------------------------------------------------------------===// // Multi-category type constraints def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>; -def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>; +def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>; // Any integer or floating-point tensor types -def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; +def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>; -def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>; +def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>; -def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; +def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>; -def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], - "number">; +def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex], + "number">; +def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef, + TF_ComplexRef], "number reference">; -def TF_NumberTensor : TensorOf<[TF_AnyNumber]>; +def TF_NumberTensor : TensorOf<[TF_Number]>; +def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>; -def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>; +def TF_NumberOrStr : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, + TF_Str]>; def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>; +//===----------------------------------------------------------------------===// +// Tensor and tensor element types + +// Bool type +def TF_Bool : I<1>; + +// Any tensor element type allowed in TensorFlow ops +// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) +def TF_ElementType : Type, + "tf.dtype">; + +// Any TensorFlow tensor type +def TF_Tensor : TensorOf<[TF_ElementType]>; + //===----------------------------------------------------------------------===// // TensorFlow attribute definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index db0a97d4b96..bc76cd3faf9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -570,36 +570,6 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect] DerivedAttr shape = TF_DerivedResultShapeAttr; } -def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> { - let summary = [{ -SparseMatMul is MatMul with hints on the sparseness of the matrices. - }]; - - let description = [{ -Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b -are sparse matrices. - }]; - - let arguments = (ins - TensorOf<[BF16, F32]>:$a, - TensorOf<[BF16, F32]>:$b, - - DefaultValuedAttr:$a_is_sparse, - DefaultValuedAttr:$b_is_sparse, - - DefaultValuedAttr:$transpose_a, - DefaultValuedAttr:$transpose_b - ); - - let results = (outs - TensorOf<[F32]>:$product - ); - - TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>; -} - - def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall", [CallOpInterface]> { let summary = @@ -1213,63 +1183,6 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let verifier = [{ return VerifyPartitionedCall(*this); }]; } -class TF_FusedBatchNormOpBase : TF_Op { - let summary = "Batch normalization."; - - let description = [{ -Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -The size of 1D Tensors matches the dimension C of the 4D Tensors. - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - - // TF_LayoutSensitiveInterface: - StringRef GetOptimalLayout(const RuntimeDevices& devices); - LogicalResult UpdateDataFormat(StringRef data_format); - }]; -} - -def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2 - ); -} - -def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> { - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); -} - def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> { let summary = [{ Batches all the inputs tensors to the computation done by the function. @@ -1341,4 +1254,98 @@ must be a Tensor or a list/tuple of Tensors. TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; } +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + +def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns 0 if the denominator is zero."; + + let description = [{ +*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x, + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; + + let description = [{ +*NOTE*: `Maximum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x, + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns x / y element-wise for real types."; + + let description = [{ +If `x` and `y` are reals, this will return the floating-point division. + +*NOTE*: `Div` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x, + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y + ); + + let results = (outs + TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; + + let hasFolder = 1; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 26fbecf387e..52a2cee5ebd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1783,6 +1783,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// - input has at least rank 1 +// - multiples is rank 1 +// - multiples.size() == input.rank() +// - input.rank() == output.rank() +// - Elements in multiples are non-negative +// - input.shape[i] * multiples[i] == output.shape[i] +// for i in [0, input.rank() - 1] + +static LogicalResult Verify(TileOp op) { + auto input_type = op.input().getType().dyn_cast(); + auto multiples_type = op.multiples().getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); + + if (multiples_type && multiples_type.getRank() != 1) { + return op.emitOpError() << "expected multiples to be rank 1, got rank = " + << multiples_type.getRank(); + } + + if (input_type && multiples_type && multiples_type.hasStaticShape() && + (input_type.getRank() != multiples_type.getNumElements() || + (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) { + return op.emitOpError() + << "expected size of multiples equal to rank of input" + << ", got multiples of size " << multiples_type.getNumElements() + << ", and input of rank " << input_type.getRank(); + } + + if (input_type && output_type) { + if (input_type.getRank() != output_type.getRank()) { + return op.emitOpError() + << "expected rank of input to equal to rank of output" + << ", got input of rank " << input_type.getRank() + << ", and output of rank " << output_type.getRank(); + } + + DenseIntElementsAttr multiples_attr; + if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) { + for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) { + const int64_t input_dim = input_type.getDimSize(i); + const int64_t output_dim = output_type.getDimSize(i); + const int64_t m = multiples_attr.getValue(i).getSExtValue(); + + if (m < 0) { + return op.emitOpError() + << "expected multiples to be non-negative, got " + << "multiples[" << i << "] = " << m; + } + + if (!ShapedType::isDynamic(input_dim) && + !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) { + return op.emitOpError() + << "requires input.shape[" << i << "] (" << input_dim << ")" + << " * " << m << " to be equal to " + << "output.shape[" << i << "] (" << output_dim << ")"; + } + } + } + } + + return success(); +} + +OpFoldResult TileOp::fold(ArrayRef operands) { + DenseIntElementsAttr multiples_attr; + if (matchPattern(multiples(), m_Constant(&multiples_attr))) { + // Return input directly when multiples are all ones, + // regardless what input is. + if (multiples_attr.isSplat() && + multiples_attr.getSplatValue().getSExtValue() == 1) { + return input(); + } + } + return {}; +} + //===----------------------------------------------------------------------===// // TopKV2Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index 05d34eb0755..6654341ab42 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi // and certain tf_executor ops are added correctly. // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" -// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] +// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]] func @next_iteration_sink_control_input() { tf_executor.graph { %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 50486909694..ff90c6f4c5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: return %0: tensor<*xf16> } +// CHECK-LABEL: testTileMultiplesAllOnes +func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense <[1, 1]> : tensor<2xi32> + // CHECK: return %arg0 + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return %0: tensor<2x3xf32> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir index bec48181b3b..726495f1fbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -220,7 +220,7 @@ func @merge_islands_only() { %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 - tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> + tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32> tf_executor.fetch } return @@ -244,7 +244,7 @@ func @merge_islands_only() { // CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) // CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> // CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]] -// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] +// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] // Test no merging took place as cycle would be formed otherwise. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index e21fd901a9e..a6b1979ee26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -7,7 +7,7 @@ # CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source # CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]] -# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]] +# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]] node { name: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 9a8d97eddf1..30a763bb687 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -3468,3 +3468,85 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> } + +// ----- + +func @testTile(%arg0: tensor<2x3x?xf32>) { + %cst = constant dense <[2, 3, 4]> : tensor<3xi32> + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32> + return +} + +// ----- + +func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) { + // expected-error @+1 {{expected multiples to be rank 1, got rank = 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) { + // expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) { + // expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}} + %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32> + return +} + +// ----- + +func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[-1, 1]> : tensor<2xi32> + // expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}} + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + return +} + +// ----- + +func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) { + %cst = constant dense <[2, 3]> : tensor<2xi32> + // expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}} + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32> + return +} + +// ----- + +// Test reference variable support for some ops (no errors expected) + +// CHECK-LABEL: @testMaximumWithRef +func @testMaximumWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.Maximum + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testAddV2WithRef +func @testAddV2WithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.AddV2 + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testRealDivWithRef +func @testRealDivWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.RealDivOp + %0 = "tf.RealDivOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @testDivNoNanWithRef +func @testDivNoNanWithRef(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tf.DivNoNanOp + %0 = "tf.DivNoNanOp"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 1e537880620..23a8e904ad9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<* %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<* %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32> %3:3 = tf_executor.NextIteration.Source : tensor<*xf32> - tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32> + tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> tf_executor.fetch %3#0 : tensor<*xf32> } return %0 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir new file mode 100644 index 00000000000..a505a4e3269 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir @@ -0,0 +1,64 @@ +// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @write_only_resource +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor<*x!tf.resource>>) +func @write_only_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]]) + // CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]]) + // CHECK-SAME: _tpu_replicate = "write" + %0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor, tensor) -> (tensor, tensor) + // CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1) + "tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + // CHECK-NEXT: return + return +} + +// CHECK-LABEL: func @write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @read_write_resource +func @read_write_resource(%arg0: tensor, %arg1: tensor, %arg2: tensor<*x!tf.resource>>) { + // CHECK-COUNT-1: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource>>) -> tensor + %1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", func = @read_write_func} : (tensor, tensor, tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @read_write_func +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor) -> (tensor, tensor) +func @read_write_func(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_write_resource +func @multiple_write_resource(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) { + // CHECK-NOT: tf.ReadVariableOp + %0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +// CHECK-LABEL: func @multiple_write_func +// CHECK-SAME: ({{%.*}}: tensor) -> (tensor, tensor) +func @multiple_write_func(%arg0: tensor) -> (tensor, tensor) { + return %arg0, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @multiple_result_user +func @multiple_result_user(%arg0: tensor, %arg1: tensor<*x!tf.resource>>) -> tensor { + // CHECK-NOT: tf.ReadVariableOp + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor) -> tensor + "tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource>>, tensor) -> () + return %0 : tensor +} + +// CHECK-LABEL: func @multiple_result_user_func +// CHECK-SAME: ({{%.*}}: tensor) -> tensor +func @multiple_result_user_func(%arg0: tensor) -> tensor { + return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index 32a8000ea82..d897c8cbd89 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op() { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @tail_single_outside_compiled_op_user() -> tensor { // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.NoOp" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster = "tf_device.cluster"() ( { %a = "tf.A"() : () -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"() : () -> () + "tf.NoOp"() : () -> () tf_device.return %b : tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor // CHECK: return %[[LAUNCH_OUT]] @@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" - // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const" // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %cluster:5 = "tf_device.cluster"() ( { %c = "tf.C"() : () -> tensor %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor - %e = "tf.E"() : () -> tensor + %e = "tf.Const"() {value = dense<0> : tensor} : () -> tensor tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 @@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { // CHECK-NOT: "tf_device.launch" // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf.C" + // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: tf_device.return "tf_device.cluster"() ( { - %a = "tf.A"(%arg0) : (tensor) -> tensor + %a = "tf.Identity"(%arg0) : (tensor) -> tensor %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - "tf.C"(%b) : (tensor) -> () + %c = "tf.Identity"(%b) : (tensor) -> tensor tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () return @@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) - // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) // CHECK-NEXT: tf_device.return %[[C_OUT]] // CHECK-NEXT: { // CHECK-DAG: num_cores_per_replica = 1 @@ -399,11 +399,72 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %b = "tf.B"() : () -> tensor %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor - %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + %e:2 = "tf.IdentityN"(%c, %a) : (tensor, tensor) -> (tensor, tensor) tf_device.return }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () tf_device.return } return } + + // CHECK-LABEL: func @side_effect_middle + func @side_effect_middle() { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"() : () -> () + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_head_no_operand + func @side_effect_head_no_operand() { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + + "tf_device.cluster"() ( { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + %c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.D"(%c) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @side_effect_tail_no_operand + func @side_effect_tail_no_operand() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + + // CHECK: "tf_device.launch"() + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> () + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 0c21078b0ad..9960ca77693 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -110,6 +110,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPUDynamicPaddingMapperPass()); + pm.addPass(CreateTPUResourceReadForWritePass()); pm.addPass(CreateTPUShardingIdentificationPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addPass(CreateTPURewritePass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index f690882b0a9..4385a2d00b0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -629,8 +629,7 @@ class Lower_UnaryOpsComposition LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op, PatternRewriter &rewriter) const override { Value result = op.x(); - for (StringRef op_name : - op.op_names().getAsRange()) { + for (StringRef op_name : op.op_names().getAsValueRange()) { std::string full_name = "tf." + op_name.str(); // All ops in the sequences have the same result type as the original // result type. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index d93d9ddccaf..7dcb1caf5d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -287,6 +287,10 @@ std::unique_ptr> CreateTPUDynamicLayoutPass(); // `tf_device.launch_func` `padding_map` attribute to its encapsulated function. std::unique_ptr> CreateTPUDynamicPaddingMapperPass(); +// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources +// the cluster only writes to. +std::unique_ptr> CreateTPUResourceReadForWritePass(); + // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime // ops. std::unique_ptr> CreateTPURewritePass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index fed4002bfcf..8a709062c67 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -34,6 +35,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -118,7 +120,10 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, // computation or other ops that can be extracted, and have no operands from // other ops in the TPU computation that cannot be extracted. llvm::SmallVector FindOutsideCompiledOpsAtHead( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector head_outside_compiled_ops; @@ -127,6 +132,15 @@ llvm::SmallVector FindOutsideCompiledOpsAtHead( if (!HasOutsideCompilationAttribute(&cluster_op)) continue; // An outside compiled op can be extracted if its operands are not from // other ops in the cluster that cannot be extracted. + + // Check if the side effecting op right before this side effecting op, if + // it is side effecting, can be head extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be head extracted. + auto predecessors = analysis.DirectControlPredecessors(&cluster_op); + if (!predecessors.empty() && + !head_outside_compiled_ops.contains(predecessors.back())) + continue; + auto walk_result = cluster_op.walk([&](Operation* op) { for (Value operand : op->getOperands()) { Operation* operand_op = GetOpOfValue(operand); @@ -168,11 +182,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, // Extracts and move outside compiled ops that have no dependencies in the // cluster to before the cluster. mlir::LogicalResult LiftHeadOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - tf_device::ClusterOp cluster, std::string* host_device, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster, + std::string* host_device, bool* cluster_updated) { llvm::SmallVector head_outside_compiled_ops = - FindOutsideCompiledOpsAtHead(cluster); + FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster); if (head_outside_compiled_ops.empty()) return success(); if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster, host_device))) @@ -191,9 +205,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps( // TPU computation or other ops that can be extracted, and have no results used // by other ops in the TPU computation that cannot be extracted. void FindOutsideCompiledOpsAtTailAndClusterResults( + const TF::SideEffectAnalysis& side_effect_analysis, tf_device::ClusterOp cluster, llvm::SmallVectorImpl* tail_outside_compiled_ops, llvm::SmallVectorImpl* cluster_results) { + const auto& analysis = side_effect_analysis.GetAnalysisForFunc( + cluster.getParentOfType()); Region* cluster_region = &cluster.body(); llvm::SmallSetVector tail_outside_compiled_ops_set; Operation* terminator = cluster.GetBody().getTerminator(); @@ -205,6 +222,15 @@ void FindOutsideCompiledOpsAtTailAndClusterResults( for (Operation& cluster_op : cluster_ops) { if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + // Check if the side effecting op right after this side effecting op, if + // it is side effecting, can be tail extracted. Because of op ordering due + // to side effects, if this is not true, this op cannot be tail extracted. + auto successors = analysis.DirectControlSuccessors( + &cluster_op, [&terminator](Operation* op) { return op != terminator; }); + if (!successors.empty() && + !tail_outside_compiled_ops_set.contains(successors.front())) + continue; + llvm::SmallVector results_to_forward; bool can_be_extracted = llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { @@ -293,13 +319,14 @@ tf_device::ClusterOp UpdateClusterResults( // Extracts and move outside compiled ops that do not create dependencies in the // cluster to after the cluster. mlir::LogicalResult LiftTailOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - std::string host_device, tf_device::ClusterOp* cluster, - bool* cluster_updated) { + OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis, + const mlir::TF::RuntimeDevices& devices, std::string host_device, + tf_device::ClusterOp* cluster, bool* cluster_updated) { llvm::SmallVector tail_outside_compiled_ops; llvm::SmallVector cluster_results; - FindOutsideCompiledOpsAtTailAndClusterResults( - *cluster, &tail_outside_compiled_ops, &cluster_results); + FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster, + &tail_outside_compiled_ops, + &cluster_results); if (tail_outside_compiled_ops.empty()) return success(); if (host_device.empty()) @@ -365,6 +392,7 @@ struct TPUExtractHeadTailOutsideCompilation }; void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + auto& side_effect_analysis = getAnalysis(); // Get runtime devices information from the closest parent module. auto module = getOperation(); mlir::TF::RuntimeDevices devices; @@ -379,10 +407,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { for (tf_device::ClusterOp cluster : clusters) { std::string host_device; bool cluster_updated = false; - if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, - &host_device, &cluster_updated)) || - failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, - &cluster, &cluster_updated))) + if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis, + devices, cluster, &host_device, + &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis, + devices, host_device, &cluster, + &cluster_updated))) return signalPassFailure(); if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc new file mode 100644 index 00000000000..cccd528da1d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc @@ -0,0 +1,140 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TFTPU { + +// A pass that finds TPU clusters with write only resource access and adds an +// associated resource read, so the resource can later be fused into TPUExecute. +namespace { +struct TPUResourceReadForWrite + : public PassWrapper> { + void runOnOperation() override; +}; + +// Helper struct holding a resource value and its associated type. +struct ResourceValueAndSubtype { + Value resource; + Type subtype; +}; + +// Finds resource handle and type for result if result writes to a resource. +ResourceValueAndSubtype GetResourceWriteResult( + tf_device::ClusterFuncOp cluster_func, Value result) { + ResourceValueAndSubtype resource; + if (!result.hasOneUse()) return resource; + Operation* result_user = *result.getUsers().begin(); + auto assign_var = dyn_cast(result_user); + if (!assign_var) return resource; + + auto handle = assign_var.resource(); + // Skip result if cluster writes to the same variable via multiple results. + for (Operation* handle_user : handle.getUsers()) { + if (handle_user == assign_var) continue; + auto assign_var_user = dyn_cast(handle_user); + if (!assign_var_user) continue; + if (assign_var_user.value().getDefiningOp() == cluster_func) + return resource; + } + + resource.resource = assign_var.resource(); + resource.subtype = assign_var.value().getType(); + return resource; +} + +// Checks if resource is read by TPU cluster. +bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func, + Value resource) { + for (Operation* resource_user : resource.getUsers()) + if (auto read = dyn_cast(resource_user)) + for (Operation* read_user : read.value().getUsers()) + if (read_user == cluster_func) return true; + + return false; +} + +void TPUResourceReadForWrite::runOnOperation() { + SmallVector cluster_funcs; + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { + cluster_funcs.push_back(cluster_func); + }); + + OpBuilder builder(&getContext()); + // Add resource reads for resource writes from TPU cluster where for such + // resources the TPU cluster does not read from. + for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) { + builder.setInsertionPoint(cluster_func); + + SmallVector read_operands; + for (Value result : cluster_func.getResults()) { + // TODO(lyandy): Update pass to use resource alias analysis. + auto resource_and_type = GetResourceWriteResult(cluster_func, result); + if (!resource_and_type.resource) continue; + if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource)) + continue; + auto new_read = builder.create( + resource_and_type.resource.getLoc(), resource_and_type.subtype, + resource_and_type.resource); + read_operands.push_back(new_read.value()); + } + + if (read_operands.empty()) continue; + + // Update caller and function types with new read operands. + auto operands = llvm::to_vector<4>(cluster_func.getOperands()); + operands.append(read_operands.begin(), read_operands.end()); + + auto new_cluster_func = builder.create( + cluster_func.getLoc(), cluster_func.getResultTypes(), operands, + cluster_func.getAttrs()); + cluster_func.replaceAllUsesWith(new_cluster_func); + FuncOp func = cluster_func.getFunc(); + Block& block = func.front(); + for (Value read_operand : read_operands) + block.addArgument(read_operand.getType()); + + func.setType(FunctionType::get(block.getArgumentTypes(), + func.getCallableResults(), &getContext())); + cluster_func.erase(); + } +} + +} // namespace + +std::unique_ptr> CreateTPUResourceReadForWritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-resource-read-for-write", + "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes " + "with no reads"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 631553b381e..0445dbb698a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -80,46 +81,14 @@ constexpr char kInvalidExecutorGraphMsg[] = constexpr char kDeviceAttr[] = "tf.device"; constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; -bool IsLegalChar(char c, bool first_char) { - if (isalpha(c)) return true; - if (isdigit(c)) return true; - if (c == '.') return true; - if (c == '_') return true; - - // First character of a node name can only be a letter, digit, dot or - // underscore. - if (first_char) return false; - - if (c == '/') return true; - if (c == '-') return true; - - return false; -} - -// Convert characters in name that are considered illegal in TensorFlow Node -// name to '.'. -std::string LegalizeNodeName(llvm::StringRef name) { - assert(!name.empty() && "expected non-empty name"); - - std::string legalized_name; - bool first = true; - for (auto c : name) { - if (IsLegalChar(c, first)) { - legalized_name += c; - } else { - legalized_name += '.'; - } - first = false; - } - - return legalized_name; -} - // OpOrArgLocNameMapper that legalizes the returned name. class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { private: std::string GetName(OpOrVal op_or_val) override { - return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val)); + std::string name = OpOrArgLocNameMapper::GetName(op_or_val); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); + return name; } }; @@ -523,13 +492,14 @@ StatusOr> Exporter::Convert( if (index >= num_data_results) break; // TODO(jpienaar): If there is a result index specified, ensure only one // and that it matches the result index of the op. - std::string orig_name(output_names[index]); - auto tensor_id = ParseTensorName(orig_name); - auto name = LegalizeNodeName( - llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); + std::string name(output_names[index]); + auto tensor_id = ParseTensorName(name); + std::string tensor_id_node(tensor_id.node()); + assert(!tensor_id_node.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(tensor_id_node); // Ensure name does not get reused. - (void)exporter.op_to_name_.GetUniqueName(name); + (void)exporter.op_to_name_.GetUniqueName(tensor_id_node); } } @@ -537,8 +507,9 @@ StatusOr> Exporter::Convert( TF_RET_CHECK(input_names.size() == block.getNumArguments()); for (const auto& it : llvm::enumerate(function.getArguments())) { // TODO(lyandy): Update when changing feed/fetch import. - std::string orig_name(input_names[it.index()]); - std::string name = LegalizeNodeName(orig_name); + std::string name(input_names[it.index()]); + assert(!name.empty() && "expected non-empty name"); + mlir::LegalizeNodeName(name); auto tensor_id = ParseTensorName(name); TF_RET_CHECK(tensor_id.index() == 0) << "input port designation not supported"; diff --git a/tensorflow/compiler/mlir/utils/array_container_utils.h b/tensorflow/compiler/mlir/utils/array_container_utils.h new file mode 100644 index 00000000000..c1a898185d9 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/array_container_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ + +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::MutableArrayRef SpanToMutableArrayRef(absl::Span span) { + return llvm::MutableArrayRef(span.data(), span.size()); +} + +template +inline absl::Span ArrayRefToSpan(llvm::ArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +template +inline absl::Span MutableArrayRefToSpan(llvm::MutableArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc new file mode 100644 index 00000000000..bc4e80f5aa1 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/name_utils.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/IR/Identifier.h" // from @llvm-project + +namespace mlir { + +namespace { +// Checks if a character is legal for a TensorFlow node name, with special +// handling if a character is at the beginning. +bool IsLegalChar(char c, bool first_char) { + if (isalpha(c)) return true; + if (isdigit(c)) return true; + if (c == '.') return true; + if (c == '_') return true; + + // First character of a node name can only be a letter, digit, dot or + // underscore. + if (first_char) return false; + + if (c == '/') return true; + if (c == '-') return true; + + return false; +} +} // anonymous namespace + +void LegalizeNodeName(std::string& name) { + if (name.empty()) return; + + if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.'; + + for (char& c : llvm::drop_begin(name, 1)) + if (!IsLegalChar(c, /*first_char=*/false)) c = '.'; +} + +std::string GetNameFromLoc(Location loc) { + llvm::SmallVector loc_names; + llvm::SmallVector locs; + locs.push_back(loc); + bool names_is_nonempty = false; + + while (!locs.empty()) { + Location curr_loc = locs.pop_back_val(); + + if (auto name_loc = curr_loc.dyn_cast()) { + // Add name in NameLoc. For NameLoc we also account for names due to ops + // in functions where the op's name is first. + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } else if (auto call_loc = curr_loc.dyn_cast()) { + // Add name if CallSiteLoc's callee has a NameLoc (as should be the + // case if imported with DebugInfo). + if (auto name_loc = call_loc.getCallee().dyn_cast()) { + auto name = name_loc.getName().strref().split('@').first; + loc_names.push_back(name); + if (!name.empty()) names_is_nonempty = true; + continue; + } + } else if (auto fused_loc = curr_loc.dyn_cast()) { + // Push all locations in FusedLoc in reverse order, so locations are + // visited based on order in FusedLoc. + auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); + locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); + continue; + } + + // Location is not a supported, so an empty StringRef is added. + loc_names.push_back(llvm::StringRef()); + } + + if (names_is_nonempty) + return llvm::join(loc_names.begin(), loc_names.end(), ";"); + + return ""; +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/name_utils.h b/tensorflow/compiler/mlir/utils/name_utils.h new file mode 100644 index 00000000000..4b08a41feec --- /dev/null +++ b/tensorflow/compiler/mlir/utils/name_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Location.h" // from @llvm-project + +namespace mlir { + +// Converts characters in name that are considered illegal in TensorFlow Node +// name to '.'. +void LegalizeNodeName(std::string& name); + +// Creates a TensorFlow node name from a location. +std::string GetNameFromLoc(Location loc); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/string_container_utils.h b/tensorflow/compiler/mlir/utils/string_container_utils.h new file mode 100644 index 00000000000..fb2fa06ca4d --- /dev/null +++ b/tensorflow/compiler/mlir/utils/string_container_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +inline absl::string_view StringRefToView(llvm::StringRef ref) { + return absl::string_view(ref.data(), ref.size()); +} + +inline llvm::StringRef StringViewToRef(absl::string_view view) { + return llvm::StringRef(view.data(), view.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 32a2ed1c272..ec98d9d29e5 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -238,7 +238,6 @@ cc_library( deps = [ ":type_to_shape", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/tf2xla:common", @@ -389,7 +388,6 @@ cc_library( ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_control_flow", "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 4bd2dfd9244..41877d39381 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -61,7 +60,7 @@ class CholeskyOpTest(xla_test.XLATestCase): dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): chol = linalg_ops.cholesky(placeholder) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) def testBasic(self): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 9d278cfbb28..08aad66abe1 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -65,7 +64,8 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): with self.test_scope(): x = linalg_ops.matrix_triangular_solve( placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) - verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + verification = test_util.matmul_without_tf32( + placeholder_ca, x, adjoint_a=adjoint) self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 5fcf254db82..b2d5db8a3a8 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -24,12 +24,17 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + "XLA QR op calls matmul. Also, matmul used for verification. Also with " + 'TF32, mysterious "Unable to launch cuBLAS gemm" error occasionally occurs') +# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): def AdjustedNorm(self, x): @@ -73,7 +78,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): with self.session() as sess: x_tf = array_ops.placeholder(dtype) - with self.test_scope(): + with self.device_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 8c31629c234..de97c6ff210 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -237,8 +237,8 @@ class XLATestCase(test.TestCase): 'test_session not supported on XLATestCase, please use session') @contextlib.contextmanager - def test_scope(self): - """Test scope that runs tests on `self.device`. + def device_scope(self): + """Scope that runs tests on `self.device`. Yields: A scope to apply to the operators under test. @@ -246,6 +246,15 @@ class XLATestCase(test.TestCase): with ops.device('device:{}:0'.format(self.device)): yield + def test_scope(self): + """Deprecated alias of `device_scope`. + + This should be avoided as the name starts with `test`, so test runners + treat it as a test. This interferes with class decorators that operate on + each test method. + """ + return self.device_scope() + def Benchmark(tf_bench, builder_fn, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 3977c5c517d..b127337e02a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1709,12 +1709,12 @@ class ParameterizedOpConverterTestBase std::tuple> { public: ParameterizedOpConverterTestBase() - : trt_mode(std::get<0>(GetParam())), - tf_type(std::get<1>(GetParam())), - converter_precision(std::get<2>(GetParam())) {} + : trt_mode_(std::get<0>(GetParam())), + tf_type_(std::get<1>(GetParam())), + converter_precision_(std::get<2>(GetParam())) {} void Reset() { - OpConverterTest::Reset(converter_precision, trt_mode); + OpConverterTest::Reset(converter_precision_, trt_mode_); input_data_.clear(); } @@ -1750,7 +1750,7 @@ class ParameterizedOpConverterTestBase if (!partial_input_shape_dims.empty()) { partial_shape = partial_input_shape_dims; } else { - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In dynamic shape mode we make all dims unknown. partial_shape = std::vector(dims.size(), -1); } else { @@ -1776,7 +1776,7 @@ class ParameterizedOpConverterTestBase void AddTestTensor(const string& name, const std::vector& dims, const std::vector& values = {}, const std::vector& partial_input_shape_dims = {}) { - AddTestTensor(name, dims, tf_type, values, partial_input_shape_dims); + AddTestTensor(name, dims, tf_type_, values, partial_input_shape_dims); } // Builds and runs the converted network. Checks output tensor shape. Tests @@ -1796,7 +1796,7 @@ class ParameterizedOpConverterTestBase TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); string out_name = (n_output == 1) ? name : StrCat(name, ":", i); DataType out_tf_type = - out_tf_types.size() > i ? out_tf_types[i] : tf_type; + out_tf_types.size() > i ? out_tf_types[i] : tf_type_; InputOutputData data{ out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)}; output_data.push_back(data); @@ -1840,9 +1840,9 @@ class ParameterizedOpConverterTestBase } protected: - const TrtTestMode trt_mode; - const DataType tf_type; - const TrtPrecisionMode converter_precision; + const TrtTestMode trt_mode_; + const DataType tf_type_; + const TrtPrecisionMode converter_precision_; DataVec input_data_; }; @@ -2075,7 +2075,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { 37.342354, 41.013527, 30.9738, 34.469433, 45.018955, 48.59309, 59.369415, 63.04059}; for (auto get_node_def : get_node_def_vec) { - NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0); + NodeDef tmp_node_def = get_node_def(tf_type_, "NCHW", true, 0); std::string op_name = tmp_node_def.op(); std::vector test_param{ {"NHWC", 0, false, 0, @@ -2097,7 +2097,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { errors::Unimplemented(StrCat("The input \"variance\" for ", op_name, " must be a constant, at my_batchnorm"))}, {"NCHW", 0, false, 0.01}}; // The last one is the only test that runs. - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { test_param.push_back( {"NCHW", 0, false, 0.01, errors::InvalidArgument( @@ -2107,7 +2107,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { for (auto p : test_param) { Reset(); NodeDef node_def = - get_node_def(tf_type, p.data_format, p.is_training, p.epsilon); + get_node_def(tf_type_, p.data_format, p.is_training, p.epsilon); for (int i = 0; i < node_input.size(); i++) { if (i == 0 || i == p.tensor_input_idx) { // The first input (x) is always added as a tensor, and it hase shape @@ -2126,7 +2126,7 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { // the first arg is a tensor. TODO(tfeher) Check if one can relax this // restriction. Status expected_status = - (i != 0 && trt_mode == TrtTestMode::kImplicitBatch) + (i != 0 && trt_mode_ == TrtTestMode::kImplicitBatch) ? errors::InvalidArgument( StrCat("Batch size doesn't match for tensor ", node_input[i].name, @@ -2134,19 +2134,19 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { "converter batch size: 3 vs 2")) : Status::OK(); std::vector partial_input_shape; - if (i == 0 && trt_mode == TrtTestMode::kDynamicShape && + if (i == 0 && trt_mode_ == TrtTestMode::kDynamicShape && !p.keep_channel_unknown) { // keep channel dim static (known) partial_input_shape.resize(4, -1); partial_input_shape[1] = node_input[i].dims[1]; } - AddTestTensor(node_input[i].name, node_input[i].dims, tf_type, + AddTestTensor(node_input[i].name, node_input[i].dims, tf_type_, node_input[i].val, partial_input_shape, expected_status); } else { AddTestWeights(node_input[i].name, node_input[i].dims, - node_input[i].val, tf_type); + node_input[i].val, tf_type_); } } TestOpConverter("my_batchnorm", node_def, node_input[0].dims, @@ -2154,12 +2154,12 @@ TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { ArrayFloatNear(expected_output)); } } -} // namespace convert +} TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); const NodeDef& node_def = transpose.operation.node()->def(); @@ -2187,13 +2187,13 @@ TEST_P(OpConverterTest1, ConvertTranspose) { {}, {3, 2, 1, 1}, {3, 2, 1, 0}, - (trt_mode == TrtTestMode::kImplicitBatch) + (trt_mode_ == TrtTestMode::kImplicitBatch) ? Status(error::UNIMPLEMENTED, "Transpose at batch dimension is not supported") : Status::OK()}, TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}}, }; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Dynamic shape tests where some shapes are known test_params.push_back(TestParamBase{ {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}}); @@ -2317,12 +2317,12 @@ TEST_F(OpConverterTest, ConvertReshape) { TEST_P(OpConverterTest1, ConvertShape) { // Get the NodeDef for Shape op. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto shape = ops::Shape(s.WithOpName("my_shape"), input); const NodeDef& node_def = shape.operation.node()->def(); Status conversion_status = - (trt_mode == TrtTestMode::kImplicitBatch) + (trt_mode_ == TrtTestMode::kImplicitBatch) ? errors::Unimplemented( "Shape is only supported for explicit batch mode.") : Status::OK(); @@ -2346,7 +2346,7 @@ TEST_P(OpConverterTest1, ConvertShape) { // we use for the unit test have no actual input tensor when it is converted // to a TensorRT network. int n_elements = 0; - if (input_is_weight(p) || trt_mode != TrtTestMode::kExplicitBatch) { + if (input_is_weight(p) || trt_mode_ != TrtTestMode::kExplicitBatch) { // Calculate the number of elements for adding input data. n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1, std::multiplies()); @@ -2355,7 +2355,7 @@ TEST_P(OpConverterTest1, ConvertShape) { if (!input_is_weight(p)) { AddTestTensor("input", p.input_dims, input_val); } else { - AddTestWeights("input", p.input_dims, input_val, tf_type); + AddTestWeights("input", p.input_dims, input_val, tf_type_); } TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status, p.runtime_status, ElementsAreArray(p.input_dims), @@ -2620,7 +2620,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (const string& data_format : {"NHWC", "NCHW"}) { for (const int trt_input_rank : {1, 2, 3, 4}) { Reset(); - NodeDef node_def = get_biasadd_nodedef(data_format, tf_type); + NodeDef node_def = get_biasadd_nodedef(data_format, tf_type_); // Add input, dims_array will be like {2, 1, ..., 1, 3} std::vector dims_array(trt_input_rank + 1, 1); @@ -2642,7 +2642,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) { for (int i = 0; i < channel_size; ++i) { bias[i] = i + 1; // bias will be {1, 2, 3, ...} } - AddTestWeights("weights", {channel_size}, bias, tf_type); + AddTestWeights("weights", {channel_size}, bias, tf_type_); // Build and run the engine. std::vector output_data; @@ -2678,7 +2678,7 @@ NodeDef GetBinaryOpNodeDef(DataType dtype) { TEST_P(OpConverterTest2, ConvertBinary) { { AttrValue dtype; - dtype.set_type(tf_type); + dtype.set_type(tf_type_); // Both inputs are weights. Reset(); NodeDef node_def = @@ -2723,19 +2723,19 @@ TEST_P(OpConverterTest2, ConvertBinary) { if (!op_test_info.count(op_name)) { FAIL() << "Binary op test map does not contain op " << op_name; } - NodeDef node_def = op_test_info[op_name].first(tf_type); + NodeDef node_def = op_test_info[op_name].first(tf_type_); std::vector input_names; std::vector> input_dims; std::vector> input_values; if (operand_1_is_tensor) { AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6}); } else { - AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type); + AddTestWeights("input1", {1, 2}, std::vector{3, 6}, tf_type_); } if (operand_2_is_tensor) { AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3}); } else { - AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type); + AddTestWeights("input2", {2, 1}, std::vector{2, 3}, tf_type_); } TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(), Status::OK(), @@ -2942,10 +2942,10 @@ TEST_P(OpConverterTest2, ConvertSquare) { // Input is weights, should fail. Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type_); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Square must be a tensor, at my_square"); @@ -2954,7 +2954,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); @@ -2967,7 +2967,7 @@ TEST_P(OpConverterTest2, ConvertSquare) { inputs[i] = value; expected_outputs[i] = value * value; } - AddTestTensor("input", {1, 1, 20}, tf_type, inputs); + AddTestTensor("input", {1, 1, 20}, tf_type_, inputs); TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(), ArrayFloatNear(expected_outputs, 0)); @@ -3094,7 +3094,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. Reset(); - const NodeDef& node_def = CreateUnaryOp(tf_type); + const NodeDef& node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -3151,7 +3151,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { FAIL() << "Activation op test map does not contain op " << op_name; } Reset(); - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); const std::vector input = {-100, -2, -1, 0, 1, 88}; AddTestTensor("input", p.input_dims, input); @@ -3179,7 +3179,7 @@ TEST_P(OpConverterTest1, ConvertActivation) { TEST_P(OpConverterTest1, ConvertExpandDims) { // Get the NodeDef for ExpandDims. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type_); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); @@ -3207,7 +3207,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {0}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3216,7 +3216,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { {}, {1, 1, 1, 2, 3}, {-5}, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status(error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_expanddims") @@ -3254,7 +3254,7 @@ TEST_P(OpConverterTest1, ConvertExpandDims) { } TEST_P(OpConverterTest1, ConvertSqueeze) { - const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); + const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch); // Get the NodeDef for Squeeze. auto get_squeeze_nodedef = [](std::vector axes, DataType tf_type) -> NodeDef { @@ -3277,7 +3277,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { {}, // input partial dims {2, 3}, // expected output dims {}, // axis - trt_mode == TrtTestMode::kExplicitBatch + trt_mode_ == TrtTestMode::kExplicitBatch ? Status::OK() : Status{error::UNIMPLEMENTED, "Squeeze is not implemented for empty squeeze_dims, at " @@ -3336,7 +3336,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { "Dimension 2 with size 2 cannot be squeezed because it must be " "size 1, at my_squeeze"}}; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // In this test we try to squeeze axis=2 which has size > 1. In dynamic // shape mode the converter sees only -1, so it cannot catch this error. squeeze_non_singleton.status = Status::OK(); // conversion status @@ -3351,7 +3351,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) { for (TestParamBase p : test_params) { SCOPED_TRACE(p); Reset(); - NodeDef node_def = get_squeeze_nodedef(p.param, tf_type); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_type_); AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, p.partial_input_dims); TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, @@ -4106,14 +4106,14 @@ TEST_F(OpConverterTest, ConvertSlice) { TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. - DataType tf_type_loc = tf_type; + DataType tf_type = tf_type_; auto get_conv2d_nodedef = - [tf_type_loc](std::vector strides = {1, 1, 1, 1}, - string padding = "SAME", string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), tf_type_loc); - auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type_loc); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -4206,12 +4206,12 @@ TEST_P(OpConverterTest1, ConvertConv2D) { node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { Reset(); NodeDef node_def = get_conv2d_nodedef(); // Channel dim unknown, should fail. AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, - TfDataTypeToTrt(tf_type)); + TfDataTypeToTrt(tf_type_)); AddTestWeights("weights", {1, 2, 1, 1}, {-1, 1}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4233,8 +4233,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { // Ok. std::vector ok_params = { -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // Basic TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4246,9 +4244,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, -#endif -// TODO(b/162448349): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // SAME padding (Asymmetric) TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4271,9 +4266,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, -#endif -// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. -#if !IS_TRT_VERSION_GE(7, 1, 3, 0) // NHWC TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, @@ -4307,7 +4299,6 @@ TEST_P(OpConverterTest1, ConvertConv2D) { /*dilations=*/{1, 1, 1, 1}, /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, -#endif }; for (int i = 0; i < ok_params.size(); i++) { @@ -4316,15 +4307,15 @@ TEST_P(OpConverterTest1, ConvertConv2D) { get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); std::vector partial_input_shape; - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // The channel dim cannot have unknown size, fix that. partial_input_shape.resize(ok_params[i].input_dims.size(), -1); int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3; partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; } - AddTestTensor("input", ok_params[i].input_dims, tf_type, ok_params[i].input, - partial_input_shape); + AddTestTensor("input", ok_params[i].input_dims, tf_type_, + ok_params[i].input, partial_input_shape); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); @@ -4851,7 +4842,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (int nDim : test_nDims) { // Input is weights, should fail. Reset(); - NodeDef node_def = get_pool_nodedef(tf_type, nDim); + NodeDef node_def = get_pool_nodedef(tf_type_, nDim); AddTestWeights("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -4960,7 +4951,7 @@ TEST_P(OpConverterTest1, ConvertPool) { for (bool is_max_pooling : {true, false}) { Reset(); NodeDef node_def = - get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding, + get_pool_nodedef(tf_type_, nDim, ksize, strides, p.padding, data_format, is_max_pooling); AddTestTensor("input", input_dims, input); TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(), @@ -5022,7 +5013,7 @@ TEST_F(OpConverterTest, ConvertTopK) { TEST_P(OpConverterTest3, ConvertGather) { // Get the NodeDef for GatherV2. Scope s = Scope::NewRootScope(); - auto params = ops::Placeholder(s.WithOpName("params"), tf_type); + auto params = ops::Placeholder(s.WithOpName("params"), tf_type_); auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); @@ -5030,7 +5021,7 @@ TEST_P(OpConverterTest3, ConvertGather) { { // Axis is a tensor, should fail. Reset(); - AddTestTensor("params", {1, 1, 2, 3}, tf_type, {}); + AddTestTensor("params", {1, 1, 2, 3}, tf_type_, {}); AddTestTensor("indices", {1, 2}, DT_INT32, {}); AddTestTensor("axis", {1}, DT_INT32, {}); RunValidationAndConversion( @@ -5075,7 +5066,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 1, 3}, /*expected_output=*/{4, 5, 6, 1, 2, 3}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the" " batch dimension, at my_gather"} @@ -5088,7 +5079,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2, 1}, /*expected_output=*/{3, 1, 6, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "Indices must have a batch size of 1 when params" " is a tensor."} @@ -5102,7 +5093,7 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2, 1, 2}, /*expected_output=*/{2, 3, 5, 6}, /*params_is_tensor=*/false, - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input axis must be zero when params is a" " weight."} @@ -5115,13 +5106,13 @@ TEST_P(OpConverterTest3, ConvertGather) { /*expected_output_shape=*/{2}, /*expected_output=*/{2, 4}, /*params_is_tensor=*/true, - trt_mode == TrtTestMode::kImplicitBatch // conversion_status + trt_mode_ == TrtTestMode::kImplicitBatch // conversion_status ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the " "batch dimension, at my_gather"} : Status::OK(), - Status::OK(), // runtime_status - trt_mode == TrtTestMode::kImplicitBatch // add_index_status + Status::OK(), // runtime_status + trt_mode_ == TrtTestMode::kImplicitBatch // add_index_status ? Status{error::INVALID_ARGUMENT, "Batch size doesn't match for tensor indices: " "Provided batch size does not match converter " @@ -5236,7 +5227,7 @@ TEST_P(OpConverterTest3, ConvertGather) { if (p.params_is_tensor) { AddTestTensor("params", p.params_shape, params_input); } else { - AddTestWeights("params", p.params_shape, params_input, tf_type); + AddTestWeights("params", p.params_shape, params_input, tf_type_); } AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {}, p.add_index_status); @@ -5276,7 +5267,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { { // Input is weights, should fail. Reset(); - const NodeDef node_def = CreateReduceOp(tf_type, false); + const NodeDef node_def = CreateReduceOp(tf_type_, false); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); AddTestWeights("axis", {1}, {1}); RunValidationAndConversion( @@ -5286,7 +5277,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { { // Axis is weights, should fail. Reset(); - const NodeDef node_def = CreateReduceOp(tf_type, false); + const NodeDef node_def = CreateReduceOp(tf_type_, false); AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); AddTestTensor("axis", {1}, DT_INT32, {1}); RunValidationAndConversion( @@ -5346,7 +5337,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { for (auto p : params) { SCOPED_TRACE(StrCat(op.name, keep_dims ? "keep_dims" : "")); Reset(); - NodeDef node_def = op.get_node(tf_type, keep_dims); + NodeDef node_def = op.get_node(tf_type_, keep_dims); AddTestTensor("input", p.input_dims, p.input_values); AddTestWeights("axis", {static_cast(p.axis.size())}, @@ -5366,7 +5357,7 @@ TEST_P(OpConverterTest1, ConvertReduce) { int ax_positive = ax >= 0 ? ax : ax + rank; // Zero marks elements that we will remove later. expected_output_dims[ax_positive] = keep_dims ? 1 : 0; - if (trt_mode == TrtTestMode::kImplicitBatch && + if (trt_mode_ == TrtTestMode::kImplicitBatch && (ax == 0 || ax == -rank)) { p.conversion_status = errors::Unimplemented( "TensorRT does not allow manipulation of the batch " @@ -5402,7 +5393,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. Reset(); - const NodeDef node_def = CreateUnaryOp(tf_type); + const NodeDef node_def = CreateUnaryOp(tf_type_); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -5458,7 +5449,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { if (!op_map.count(op_name)) { FAIL() << "Unary op test map does not contain op " << op_name; } - NodeDef node_def = op_map[op_name].first(tf_type); + NodeDef node_def = op_map[op_name].first(tf_type_); // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for // now. Need to find a better way to express input and output types. @@ -5466,7 +5457,7 @@ TEST_P(OpConverterTest1, ConvertUnary) { // TODO(tfeher): improve tests by defining an expected output data type and // check that. Currently only the shape and values of the output are // checked. - DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type; + DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type_; std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; AddTestTensor("input", p.input_dims, input_tf_type, input_values); @@ -6033,7 +6024,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/1, /*expected_output_dims=*/{1, 2, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "The input \"values_1\" for Pack must be a tensor, at " "my_pack"} @@ -6059,7 +6050,7 @@ TEST_P(OpConverterTest2, ConvertPack) { /*axis=*/-4, /*expected_output_dims=*/{2, 1, 2, 3}, /*expected_output=*/InitTestVector(12), - trt_mode == TrtTestMode::kImplicitBatch + trt_mode_ == TrtTestMode::kImplicitBatch ? Status{error::UNIMPLEMENTED, "TensorRT does not allow manipulation of the batch " "dimension, at my_pack"} @@ -6119,7 +6110,7 @@ TEST_P(OpConverterTest2, ConvertPack) { }, }; // Inputs have inconsistent shapes, should fail. - if (trt_mode != TrtTestMode::kDynamicShape) { + if (trt_mode_ != TrtTestMode::kDynamicShape) { params.push_back(TestParams{ /*input_shapes=*/{{1, 2, 3}, {1, 3, 2}}, /*partial_input_shapes=*/{{}, {}}, @@ -6139,7 +6130,7 @@ TEST_P(OpConverterTest2, ConvertPack) { // TODO(tfeher) Add dynamic shapes test once TRT handles shape error // decently } - if (trt_mode == TrtTestMode::kDynamicShape) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { // Test with mixed dynamic / static shape input tensors params.push_back( TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, @@ -6155,14 +6146,14 @@ TEST_P(OpConverterTest2, ConvertPack) { const int num_inputs = p.input_shapes.size(); EXPECT_EQ(num_inputs, p.input_values.size()); - NodeDef node_def = GetPackNodeDef(tf_type, num_inputs, p.axis); + NodeDef node_def = GetPackNodeDef(tf_type_, num_inputs, p.axis); // Create inputs. for (int j = 0; j < num_inputs; ++j) { if (j == 1 && p.input_1_is_weight) { AddTestWeights(StrCat("values_", j), p.input_shapes[j], - p.input_values[j], tf_type); + p.input_values[j], tf_type_); } else { - AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type, + AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type_, p.input_values[j], p.partial_input_shapes[j]); } } @@ -6690,7 +6681,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { { // Input is a weight, should fail. Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestWeights("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestTensor("y", {1, 1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, @@ -6717,7 +6708,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { /*value_y=*/std::vector(7 * 5, 0), /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/common_input, - trt_mode == TrtTestMode::kDynamicShape + trt_mode_ == TrtTestMode::kDynamicShape ? Status::OK() : errors::InvalidArgument("Infeasible broadcast scheme"), errors::Internal( @@ -6743,7 +6734,7 @@ TEST_P(OpConverterTest2, ConvertSquaredDifference) { for (auto p : params) { Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_); AddTestTensor("x", p.dims_x, p.value_x); AddTestTensor("y", p.dims_y, p.value_y); TestOpConverter("my_squared_diff", node_def, p.expected_output_dims, @@ -6779,7 +6770,7 @@ template void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; - std::vector> params{ + std::vector> params { // TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x. #if !IS_TRT_VERSION_GE(7, 1, 3, 0) { diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f8319cd446a..4d8f6f96811 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -1143,7 +1143,11 @@ Status ValidateGraph(const Graph* graph, return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", - s.error_message(), ")", FormatNodeForError(*node))); + s.error_message(), ")", FormatNodeForError(*node), + "One approach is to outside compile the unsupported ops to run on " + "CPUs by enabling soft placement " + "`tf.config.set_soft_device_placement(True)`." + " This has a potential performance penalty.")); } return Status::OK(); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b0d93cde846..762700eaea8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -129,8 +129,6 @@ class XlaCompiler { // Resource updates are converted into input / output of xla. The two // buffers are aliased with other if this option is true. - // - // Currently only supports TPU. bool alias_resource_update = false; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d2f174eadb5..b2a18492c57 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -305,6 +305,7 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tf32_utils", ], ) @@ -345,6 +346,9 @@ cc_library( hdrs = ["sorting.h"], deps = [ ":comparators", + ":constants", + ":loops", + ":slicing", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index ec1cc7e0487..dbb73602801 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -395,7 +395,6 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } DotDimensionNumbers dnums; - std::vector lhs_outer_dims; auto is_batch_dim = [&](int64 d) { return x_map.contains(d) && y_map.contains(d) && output_map.contains(d); }; @@ -408,11 +407,13 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, }; absl::InlinedVector rhs_outer_dims; + absl::InlinedVector lhs_outer_dims; absl::InlinedVector rhs_delete_dims; absl::InlinedVector lhs_delete_dims; for (int64 i = 0; i < x_rank; ++i) { auto dim_name = x_config[i]; const int64 rhs_dim = rhs_dimension_number(dim_name); + if (is_batch_dim(dim_name)) { if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) { dnums.add_lhs_batch_dimensions(i); @@ -448,30 +449,34 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } absl::c_sort(rhs_outer_dims); - absl::InlinedVector output_transpose_dims; - absl::InlinedVector output_reduce_dims; - auto output_dimension_number = [&](int64 d) { + + auto output_dimension_number = [&](int64 d) -> absl::optional { auto pos = absl::c_find(output_config, d); if (pos == output_config.end()) { - const int64 dim = - output_transpose_dims.size() + output_reduce_dims.size(); - output_reduce_dims.push_back(dim); - } else { - output_transpose_dims.push_back(pos - output_config.begin()); + return absl::nullopt; } + return pos - output_config.begin(); }; for (auto d : dnums.lhs_batch_dimensions()) { - output_dimension_number(x_config[d]); + output_transpose_dims.push_back(*output_dimension_number(x_config[d])); } for (auto d : lhs_outer_dims) { - output_dimension_number(x_config[d]); + if (auto output_dim = output_dimension_number(x_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + lhs_delete_dims.push_back(d); } for (auto d : rhs_outer_dims) { - output_dimension_number(y_config[d]); + if (auto output_dim = output_dimension_number(y_config[d])) { + output_transpose_dims.push_back(*output_dim); + continue; + } + rhs_delete_dims.push_back(d); } const int64 transpose_rank = output_transpose_dims.size(); @@ -482,29 +487,31 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, // Remove ones that where broadcasted from the x and the y shape and adjust // the dimension numbers that are more minor than those dimensions. + absl::c_sort(lhs_delete_dims); DeleteDimsFromContainer(lhs_delete_dims, &x_shape, dnums.mutable_lhs_batch_dimensions(), dnums.mutable_lhs_contracting_dimensions()); + + absl::c_sort(rhs_delete_dims); DeleteDimsFromContainer(rhs_delete_dims, &y_shape, dnums.mutable_rhs_batch_dimensions(), dnums.mutable_rhs_contracting_dimensions()); if (!lhs_delete_dims.empty()) { - x = Reshape(x, x_shape.dimensions()); + x = Reduce(x, ScalarLike(x, 0), + CreateScalarAddComputation(x_shape.element_type(), builder), + lhs_delete_dims); } if (!rhs_delete_dims.empty()) { - y = Reshape(y, y_shape.dimensions()); + y = Reduce(y, ScalarLike(y, 0), + CreateScalarAddComputation(y_shape.element_type(), builder), + rhs_delete_dims); } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto dot = DotGeneral(x, y, dnums, &precision_proto); - if (!output_reduce_dims.empty()) { - dot = Reduce(dot, ScalarLike(dot, 0), - CreateScalarAddComputation(x_shape.element_type(), builder), - output_reduce_dims); - } dot = Transpose(dot, transpose_dims); if (transpose_rank == output_rank) { return dot; diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index a61f243e126..9752f844dfd 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -27,12 +27,14 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tf32_utils.h" namespace { using QrTest = xla::ClientLibraryTestBase; XLA_TEST_F(QrTest, Simple) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -61,6 +63,7 @@ XLA_TEST_F(QrTest, Simple) { } XLA_TEST_F(QrTest, ZeroDiagonal) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -88,6 +91,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) { } XLA_TEST_F(QrTest, SimpleBatched) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array3D a_vals({ diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 750237c2000..5a7a70192d1 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -27,6 +30,19 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; + int64 last_dim_size = input_shape.dimensions(last_dim); + // TODO(b/165839365): tune these constants for better performance. + int64 kPerPartitionSize = 8192; // 2^13 + int64 kLastDimSizeThreshold = 524288; // 2^19 + int64 kMinNumPartitions = 8; + if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) && + last_dim_size >= kLastDimSizeThreshold) { + int64 num_partitions = + CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); + if (num_partitions >= kMinNumPartitions) { + return TopKWithPartitions(input, k, num_partitions); + } + } Shape iota_shape = ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); @@ -80,30 +96,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { } } - XlaOp values, indices; - for (int64 partition = 0; partition < num_partitions; partition++) { - std::vector start_indices(input_shape.dimensions_size(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); - start_indices[last_dim] = partition * per_partition_size; - limit_indices[last_dim] = - std::min((partition + 1) * per_partition_size, last_dim_size); - // Slice value and indices for this partition.. - XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + auto topk_body_fn = + [&](XlaOp partition, absl::Span values_and_indices, + XlaBuilder* builder) -> StatusOr> { + auto values = values_and_indices[0]; + auto indices = values_and_indices[1]; + auto input = values_and_indices[2]; + auto iota_s32 = values_and_indices[3]; + + // Slice value and indices for this partition. + XlaOp start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + XlaOp sliced_input = + DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = - Slice(iota_s32, start_indices, limit_indices, strides); + DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size}); // Concat with previous results. - if (partition > 0) { - sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); - sliced_indices = - ConcatInDim(builder, {indices, sliced_indices}, last_dim); - } + sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); + sliced_indices = + ConcatInDim(builder, {indices, sliced_indices}, last_dim); // Sort this slice XlaOp sort_result = Sort({sliced_input, sliced_indices}, CreateScalarGtComputation({input_shape.element_type(), S32}, sliced_indices.builder()), - last_dim, /*is_stable=*/true); + last_dim, true); + + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); // Slice topk. start_indices[last_dim] = 0; limit_indices[last_dim] = k; @@ -111,8 +132,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { limit_indices, strides); indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); - } - return Tuple(builder, {values, indices}); + return std::vector{values, indices, input, iota_s32}; + }; + + // Get the values and indices for the first topk so that they can + // be passed to the while loop. + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); + start_indices[last_dim] = 0; + limit_indices[last_dim] = per_partition_size; + // Slice value and indices for the first partition. + XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + XlaOp sliced_indices = + Slice(iota_s32, start_indices, limit_indices, strides); + // Sort this slice + XlaOp sort_result = + Sort({sliced_input, sliced_indices}, + CreateScalarGtComputation({input_shape.element_type(), S32}, + sliced_indices.builder()), + last_dim, /*is_stable=*/true); + + // Slice topk. + start_indices[last_dim] = 0; + limit_indices[last_dim] = k; + XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + + // Pass the result of the first TopK to the while loop and do + // num_partition - 1 iterations. + TF_ASSIGN_OR_RETURN(auto values_and_indices, + ForEachIndex(num_partitions - 1, S32, topk_body_fn, + {values, indices, input, iota_s32}, + "topk_with_partition", builder)); + return Tuple(builder, {values_and_indices[0], values_and_indices[1]}); }); } diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index e01f6faf59e..e820d5bfe6f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) { ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); } +XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) { + XlaBuilder builder(TestName()); + Array input({2, 1000000}); + input.FillRandom(1.0f, 2.0f); + auto x = + CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder); + Array2D expected_array(2, 1000); + expected_array.Fill(2.0f); + xla::GetTupleElement(xla::TopK(x, 1000), 0); + ErrorSpec error_spec(10.0f, 10.0f); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec); +} + XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) { XlaBuilder builder(TestName()); auto x_rev = diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 1bed959e3e6..39711534f79 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -695,7 +695,7 @@ struct ExecuteOptions { int32 launch_id = 0; // If non-null, an opaque context passed to an execution that may be used to // supply additional arguments to a derived class of PjRtExecutable. - ExecuteContext* context = nullptr; + const ExecuteContext* context = nullptr; }; // Represents a compiled computation that can be executed given handles to diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 046fadb405b..e1eb93f8dba 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -264,6 +264,7 @@ cc_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@pybind11", ], diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 9deb2d1c755..6f7115270bf 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/synchronization/notification.h" #include "absl/types/optional.h" #include "pybind11/cast.h" #include "pybind11/numpy.h" @@ -229,14 +230,18 @@ struct CacheEntry { // We need py::object to maintain the objects alive. std::vector out_avals; std::vector out_lazy_exprs; + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been insterted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + absl::optional compilation_error = absl::nullopt; }; // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the // bookkeeping of the different signatures used and the dispatch of calls to -// the correct underlying `PyExecutable`. -// TODO(jblespiau): This class is thread-unsafe. Note that using a mutex for the -// full `Call` will lead to a deadlock because it goes back to Python which will -// release the GIL. +// the correct underlying `PyExecutable`. This class is thread-safe. class CompiledFunction { public: CompiledFunction(py::function fun, py::function cache_miss_fun, @@ -293,6 +298,20 @@ class CompiledFunction { // to `Call`. std::shared_ptr pyclient_ = nullptr; xla::PjRtDevice* default_device_ = nullptr; + + // IMPORTANT: The GIL is not always held, because we call back to Python and + // Python will release the GIL. + // Thus, we protect the critical section modifying the `executables_` map + // and more generally the compilation with some `absl::Notification`. + // The first thread reaching such point will be responsible to create the + // notification for the executable and others will wait until notified. + // It's safe because the first thread will be holding the GIL while + // initializing the `Notification`. + // + // absl::optional is not supported + bool first_compilation_started_ = false; + absl::Notification first_compilation_complete_; + absl::optional first_compilation_error_ = absl::nullopt; }; CompiledFunction::CompiledFunction(py::function fun, @@ -617,6 +636,13 @@ CacheEntry& CompiledFunction::GetCacheEntry( absl::optional cache_miss_return) { auto found_iterator = executables_.find(signature); if (found_iterator != executables_.end()) { // Cache hit! + if (!found_iterator->second->compilation_complete.HasBeenNotified()) { + py::gil_scoped_release gil_release; + found_iterator->second->compilation_complete.WaitForNotification(); + if (found_iterator->second->compilation_error) { + throw found_iterator->second->compilation_error.value(); + } + } return *(found_iterator->second); } return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return); @@ -628,7 +654,7 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry( // We need to insert the element. auto result = executables_.emplace(signature, std::make_unique()); auto it = result.first; - + CacheEntry& cache_entry = *(it->second.get()); // CallSignatures in the cache own their keyword argument reference. result.first->first.IncRef(); @@ -637,34 +663,40 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry( if (cache_miss_return) { executable_and_pytree = cache_miss_return.value(); } else { - executable_and_pytree = cache_miss_fun_(*args, **kwargs); + try { + executable_and_pytree = cache_miss_fun_(*args, **kwargs); + } catch (const std::exception& e) { + cache_entry.compilation_error = e; + cache_entry.compilation_complete.Notify(); + throw; + } } if (executable_and_pytree.size() != 4) { throw std::runtime_error( "AssertionError: The cache miss function should return 4 " "arguments."); } - it->second->executable = py::cast>( + cache_entry.executable = py::cast>( std::move(executable_and_pytree[0])); int num_devices = - it->second->executable->pjrt_executable().local_devices().size(); + cache_entry.executable->pjrt_executable().local_devices().size(); if (num_devices != 1) { throw std::runtime_error(absl::StrCat( "Running on more than a single device is not currently supported." "The underlying PjRtExecutable has ", num_devices)); } - it->second->device = - it->second->executable->pjrt_executable().local_devices()[0]; - it->second->out_pytree_def = py::cast(executable_and_pytree[1]); + cache_entry.device = + cache_entry.executable->pjrt_executable().local_devices()[0]; + cache_entry.out_pytree_def = py::cast(executable_and_pytree[1]); py::list shaped_arrays = py::reinterpret_borrow(executable_and_pytree[2]); py::list lazy_expressions = py::reinterpret_borrow(executable_and_pytree[3]); - it->second->out_avals.reserve(shaped_arrays.size()); - it->second->out_lazy_exprs.reserve(lazy_expressions.size()); + cache_entry.out_avals.reserve(shaped_arrays.size()); + cache_entry.out_lazy_exprs.reserve(lazy_expressions.size()); int num_outputs = shaped_arrays.size(); for (int i = 0; i < num_outputs; ++i) { @@ -673,11 +705,12 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry( py::object lazy_expr = py::reinterpret_borrow(lazy_expressions[i]); - it->second->out_avals.push_back(shaped_array); - it->second->out_lazy_exprs.push_back(lazy_expr); + cache_entry.out_avals.push_back(shaped_array); + cache_entry.out_lazy_exprs.push_back(lazy_expr); } - return *(it->second); + cache_entry.compilation_complete.Notify(); + return cache_entry; } py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { @@ -687,14 +720,36 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { ParsedArgumentsAsBuffers arguments; FlattenArguments(args, kwargs, static_argnums_, arguments); + // TODO(jblespiau): It would be preferable to have a single location for + // locking code. absl::optional cache_miss_result = absl::nullopt; if (!default_device_) { - cache_miss_result = cache_miss_fun_(*args, **kwargs); - auto executable = py::cast>( - cache_miss_result.value()[0]); + // TODO(jblespiau): This code will deadlock if a jitted function + // recursively calls itself. + if (first_compilation_started_) { + if (!first_compilation_complete_.HasBeenNotified()) { + py::gil_scoped_release gil_release; + first_compilation_complete_.WaitForNotification(); + if (first_compilation_error_) { + throw first_compilation_error_.value(); + } + } + } else { + first_compilation_started_ = true; + try { + cache_miss_result = cache_miss_fun_(*args, **kwargs); + } catch (const std::exception& e) { + first_compilation_error_ = e; + first_compilation_complete_.Notify(); + throw; + } + auto executable = py::cast>( + cache_miss_result.value()[0]); - pyclient_ = executable->client(); - default_device_ = executable->LocalDevices()[0].contents; + pyclient_ = executable->client(); + default_device_ = executable->LocalDevices()[0].contents; + first_compilation_complete_.Notify(); + } } // The C++ jit do not support Tracers arguments yet. The Python-based jit diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index d5977f4f0cf..06605660b63 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -171,13 +171,13 @@ class TraceMeWrapper : public tensorflow::profiler::TraceMeWrapper { void BuildProfilerSubmodule(py::module* m) { py::module profiler = m->def_submodule("profiler", "TensorFlow profiler integration"); - py::class_> + py::class_> profiler_server_class(profiler, "ProfilerServer"); profiler.def( "start_server", - [](int port) -> std::unique_ptr { - auto server = absl::make_unique(); + [](int port) -> std::unique_ptr { + auto server = absl::make_unique(); server->StartProfilerServer(port); return server; }, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index dd16bd32dd1..a1d6959ed37 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2343,6 +2343,7 @@ cc_library( ":hlo_dce", ":hlo_pass", ":hlo_pass_pipeline", + ":hlo_verifier", ":tuple_simplifier", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 214cbfa93a7..9c4208da098 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3303,6 +3303,9 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; bool has_negative = false; + // Used to possibly split off the unchanged padding dimensions. + std::vector padding_dimensions; + int64 dimension_index = 0; for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { @@ -3311,12 +3314,93 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() != 0 || padding_dimension.edge_padding_high() != 0) { all_zero = false; + padding_dimensions.push_back(dimension_index); + } else if (padding_dimension.interior_padding()) { + padding_dimensions.push_back(dimension_index); } + dimension_index++; } if (all_zero) { - ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); - return Status::OK(); + if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) { + return Status::OK(); + } + } + + // The context of this optimization can be found at b/163617402 + // It tries to capture the case of pad(broadcast(x)), where + // x->shape().dimensions(), or broadcast(x)->dimensions(), is + // a subset of the padded dimensions in pad->config(), + // and the padded dimensions in pad->config() is in turn a strict + // subset of broadcast->shape().dimensions(). The combined op can be + // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends + // x with dimensions that need to be padded, and broadcast2 extends + // the result of padding to full dimensions. + // TODO(qyi): for future extensions: The condition for broadcast(x) + // ->dimensions() to be a subset of padded dimensions in pad->config() + // does not have to be strictly required, but it makes the calculation + // for optimization easier, so it is required by the current implementation. + // Only the second condition between the padded dimensions and the + // dimensions of the final shape have to be enforced for the optimization + // to make sense. If needed to remove the first constraint, the shape + // calculations across the implementation need to be re-adjusted. + auto pad_dims = padding_dimensions.size(); + if (pad_dims < dimension_index && + pad->operand(0)->opcode() == HloOpcode::kBroadcast && + pad->operand(0)->user_count() == 1 && + pad->operand(0)->operand(0)->shape().rank() <= pad_dims) { + // Check broadcast operand dimensions is a subset of pading_dimensions. + // If not, skip the optimization. + bool opt_is_valid = true; + std::vector broadcast_dimensions; + HloBroadcastInstruction* broadcast = + static_cast(pad->mutable_operand(0)); + for (auto broadcast_index : broadcast->dimensions()) { + bool found = false; + for (int i = 0; i < pad_dims; ++i) { + if (broadcast_index == padding_dimensions[i]) { + broadcast_dimensions.push_back(i); + found = true; + break; + } + } + if (!found) { + opt_is_valid = false; + break; + } + } + if (opt_is_valid) { + auto pad_shape = pad->shape(); + auto broadcast_shape = broadcast->shape(); + auto pad_shape1 = pad_shape; + auto broadcast_shape1 = broadcast_shape; + PaddingConfig pad_config; + for (int i = padding_dimensions.size() - 1; i >= 0; --i) { + int64 j = padding_dimensions[i]; + while (--dimension_index > j) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + } + while (--dimension_index >= 0) { + broadcast_shape1.DeleteDimension(dimension_index); + pad_shape1.DeleteDimension(dimension_index); + } + for (auto dimension_to_pad : padding_dimensions) { + auto dimension = pad_config.add_dimensions(); + *dimension = pad->padding_config().dimensions(dimension_to_pad); + } + *broadcast->mutable_shape() = broadcast_shape1; + *broadcast->mutable_dimensions() = broadcast_dimensions; + simplifier_->UpdateLayout(broadcast->mutable_shape()); + auto pad2 = + computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1)); + *pad2->mutable_padding_config() = pad_config; + simplifier_->UpdateLayout(pad2->mutable_shape()); + auto broadcast2 = computation_->AddInstruction( + HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions)); + return ReplaceInstruction(pad, broadcast2); + } } if (has_negative) { @@ -3351,7 +3435,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad->shape(), nonzero_pad->mutable_shape())); simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); - // Second, construct the slice instruction to perform the negative padding. + // Second, construct the slice instruction to perform the negative + // padding. std::vector start_indices; std::vector end_indices; std::vector strides; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 70147f6ecad..761b5fa0a82 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6955,5 +6955,57 @@ TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) { GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant())))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[] constant(true) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Constant()), m::Constant()))))); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) { + const char* kModuleStr = R"( + HloModule m + test { + c1 = pred[32] parameter(0) + b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2} + c3 = pred[] constant(false) + p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064 + ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Broadcast( + m::Pad(m::Broadcast(m::Parameter()), m::Constant()))))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index ce80b4cfc15..42caf20ff80 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -114,6 +115,8 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { case HloOpcode::kConstant: case HloOpcode::kGetTupleElement: return 0; + case HloOpcode::kConditional: + return 10; default: // Assume fusion will not happen anyway if user count > 1) if (op->user_count() > 1) { @@ -582,6 +585,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( // to replace the conditional directly in the new computation. b_opd_use.mutable_operands().push_back(conditional); } + HloInstruction* new_root = computation->AddInstruction(HloInstruction::CreateTuple(operands)); VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; @@ -592,6 +596,15 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } + // Update get tuple element index of the conditional. + if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() > use_index) { + user->set_tuple_index(user->tuple_index() - 1); + } + } + } hoisted_instructions[conditional] = b_old_root; int64 cp_start = 0; if (use_index >= 0) { @@ -677,7 +690,7 @@ class GroupConnectedBoundaries { : conditional_(conditional), conditional_parent_(conditional->parent()), is_layout_sensitive_(is_layout_sensitive) {} - // Returns true if `instruction` is worth hoisting out. + // Returns true if `instruction` is worth hoisting. bool WorthHoisting(HloInstruction* instruction) { // This is needed for the "moving-in" transformation, to prevent the root // of the parent computation (which contains the conditional) to be moved @@ -708,6 +721,7 @@ class GroupConnectedBoundaries { case HloOpcode::kAllReduce: case HloOpcode::kAdd: case HloOpcode::kPower: + case HloOpcode::kCopy: case HloOpcode::kConstant: case HloOpcode::kSubtract: case HloOpcode::kMultiply: @@ -1070,6 +1084,7 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); + subpipeline.AddPass(false, true); TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); changed |= cleanup_changed; } diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b91f3813980..e5e3873cc66 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -728,6 +728,66 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } +TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0) + constant.1 = s32[] constant(4) + get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0 + add.1 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1 + slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2), + slice={[0:4:1], [0:3:1]} + constant.2 = f32[] constant(0.0) + ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2) +} + +branch2 { + arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0 + copy.1 = s32[] copy(get-tuple-element.3) + get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1 + copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4) + constant.2 = f32[] constant(0.0) + ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1) + tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2) + conditional = (f32[4,3]{0,1}, s32[], f32[]) + conditional(pred.1, tuple.3, tuple.4), true_computation=branch1, + false_computation=branch2 + get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0 + get-first-index = s32[] get-tuple-element(conditional), index=1 + get-second-index = f32[] get-tuple-element(conditional), index=2 + copy.3 = f32[4,3]{1,0} copy(get-zero-index) + ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index, + get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + VLOG(1) << module->ToString(); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 8); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2), + op::GetTupleElement(op::Conditional(), 0), + op::GetTupleElement(op::Conditional(), 1)))); +} + } // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d1d0827981e..ce761d8e0ae 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -254,6 +254,7 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", + "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b994ead17ca..4680f072140 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -60,6 +60,7 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, // Output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { + VLOG(4) << "Producer " << producer->name() << " is a fusion op"; return false; } // Cost condition: not fuse (simple, expensive producers) and (consumers who @@ -67,11 +68,15 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, if (producer->opcode() != HloOpcode::kFusion && consumer->ReusesOperandElements(operand_index) && is_expensive(*producer)) { + VLOG(4) << "Do not fuse simple, expensive producer " << producer->name() + << " and consumer which reuses operand elements."; return false; } if (!IsProducerConsumerFusible(*producer, *consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { + VLOG(4) << "Producer " << producer->name() + << " is not fusible or should not be fused."; return false; } return true; @@ -107,8 +112,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, fusion_node_evaluations_.emplace(consumer, FusionNodeIndexingEvaluation(consumer)); } - return !fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( - producer); + if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh( + producer)) { + VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name() + << " would result in overly large code duplication."; + return false; + } + return true; } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 7d5a8d032e6..2a493fe4494 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -17,7 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #include "llvm/IR/Module.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" @@ -44,7 +47,11 @@ class IrEmitterContext { cuda_compute_capability_(cuda_compute_capability), profile_index_map_(profile_index_map), mlir_context_(mlir_context), - llvm_module_(llvm_module) {} + llvm_module_(llvm_module) { + mlir_context_ + ->loadDialect(); + } // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; IrEmitterContext& operator=(const IrEmitterContext&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b9146dd8fae..c2955689f98 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -404,11 +404,10 @@ class IrEmitterUnnested : public IrEmitter, // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is // expected to have the operand values in it already. If unique_indices - // is false, we will use an atomic update. Using false for unique_indices - // is safe only when it is guaranteed that there are no duplicate - // indices. - // When using unique_indices=true, it is the caller's responsibility to - // ensure there is no overlap. + // is false, we will use an atomic update. Using true for unique_indices + // behaves properly only when it is guaranteed that the indices to be + // updated do not overlap. The caller is responsible for ensuring this is + // the case. Status EmitScatter(Thunk* thunk, HloInstruction* scatter, const llvm_ir::ElementGenerator& scatter_indices_gen, const llvm_ir::ElementGenerator& updates_gen); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index bb02319a261..da4e3d61a81 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -470,13 +470,19 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding, if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { index_tile_assignment_dims.push_back(1); } + if (data_sharding.ReplicateOnLastTileDim()) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = data_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(index_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return data_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterDataSharding(const HloSharding& index_sharding, @@ -496,13 +502,19 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding, index_dim++; } } + if (index_sharding.ReplicateOnLastTileDim()) { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dimensions().back()); + } Array new_tile_assignment = index_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(data_tile_assignment_dims)) { return HloSharding::Replicate(); } new_tile_assignment.Reshape(data_tile_assignment_dims); - return HloSharding::Tile(new_tile_assignment); + return index_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index b290b1bd68b..2085b1ea4d0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -516,11 +516,12 @@ StatusOr InstructionFusion::Run(HloModule* module) { continue; } - VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); + VLOG(5) << "Considering fusion of: " << instruction->ToString() + << " with operand " << operand->name(); if (!operand->IsFusible()) { VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible"; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index c53f2c19695..12cdb17a0a5 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -1273,10 +1273,13 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( } } - // Bitcasts don't define buffers and don't directly consume buffers. Skip - // allocating buffers for bitcast uses. The uses that feed from bitcasts - // will be handled specially. - if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) { + // Bitcasts don't define buffers and don't directly consume buffers. Skip + // allocating buffers for bitcast uses (unless they are the root + // instruction). The uses that feed from bitcasts will be handled + // specially. + if (hlo_use.instruction->opcode() != HloOpcode::kBitcast || + hlo_use.instruction == + hlo_use.instruction->parent()->root_instruction()) { AllocationRequest request; // Rarely, (e.g., when conditional true and false parameters are the // same), definition time can be the time of the conditional and use diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index cc4f740bc25..f9ca0f8309b 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -4066,6 +4066,51 @@ TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) { find_schedule_index(cos->operand(0))); } +TEST_P(MemorySpaceAssignmentTest, BitcastRoot) { + // Tests against a bug where the root of entry computation is a bitcast + // instruction and it ends up getting an allocation in the alternate memory. + absl::string_view hlo_string = R"( +HloModule primitive_computation_gather.4, is_scheduled=true + +%while_body { + %param.1 = (s32[], f32[3,3,3]) parameter(0) + %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0 + %copy.6 = s32[] copy(s32[] %get-tuple-element.32) + %constant.8 = s32[] constant(1) + %add = s32[] add(s32[] %copy.6, s32[] %constant.8) + %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1 + negate = f32[3,3,3] negate(get-tuple-element.35) + ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate) +} + +%while_cond { + %param.0 = (s32[], f32[3,3,3]) parameter(0) + %get-tuple-element = s32[] get-tuple-element(%param.0), index=0 + %constant.3 = s32[] constant(3) + ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT +} + +ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] { + %constant.1 = s32[] constant(0) + %copy.11 = s32[] copy(s32[] %constant.1) + %constant = f32[] constant(0) + %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={} + %tuple.8 = (s32[], f32[3,10,5], s32[3,1], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast) + %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body + %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1 + ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(!root->shape().has_layout() || + root->shape().layout().memory_space() == kDefaultMemorySpace); +} + // A mock MemorySpaceAssignmentRepacker class that accepst a map of // (start_time,offset) -> new_offset values. Using this map, the repacker // repacks the allocations to the new_offset. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 786f28c7705..af670eb059f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -41,9 +41,12 @@ cc_library( srcs = ["emission_context.cc"], hdrs = ["emission_context.h"], deps = [ + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla/service:hlo", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index cb5ea946c1b..06c7ebd1099 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -16,8 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" #include "absl/strings/substitute.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -25,7 +28,8 @@ namespace mlir_gpu { EmissionContext::EmissionContext(std::unique_ptr module) : module_(std::move(module)), context_() { - context_.loadAllGloballyRegisteredDialects(); + context_.loadDialect(); error_handler_ = [](const ErrorMap& instructions_with_error, HloModule* module) { std::set computations_with_error; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index ef39e90c2e5..b89e84e8afe 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -2025,6 +2025,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, UpdateOperandToScatter) { const char* const hlo_string = R"( HloModule module @@ -2056,6 +2088,70 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToDataOperand) { const char* const hlo_string = R"( HloModule module @@ -2087,6 +2183,38 @@ ENTRY entry { op::Sharding("{devices=[1,2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1), sharding={replicated} + %p2 = f32[3,9] parameter(2) + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) { const char* const hlo_string = R"( HloModule module @@ -2149,6 +2277,38 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1), sharding={replicated} + %indices = s32[3] copy(%p1) + %updates = f32[3,9] parameter(2), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) { const char* const hlo_string = R"( HloModule module @@ -2180,6 +2340,38 @@ ENTRY entry { op::Sharding("{devices=[2,1]0,1}")); } +TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + %p2 = f32[3,9] parameter(2), sharding={replicated} + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) { const char* const hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index d5d3730f3a7..323bea92a36 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1451,10 +1451,23 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { update_dim_to_index_dim); CHECK(new_updates_sharding.has_value()); updates = updates.Reshard(*new_updates_sharding); + // Update collective_ops_creator and partition_id for partial replicate. + auto collective_ops_creator = collective_ops_creator_; + auto partition_id = partition_id_; + if (indices.sharding().ReplicateOnLastTileDim()) { + auto sharding_grouped = GroupShardingOnDims( + indices.sharding(), + {indices.sharding().tile_assignment().num_dimensions() - 1}); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + indices.state(), sharding_grouped.device_groups, &b_); + collective_ops_creator = + per_group_partitioner_state.collective_ops_creator; + partition_id = per_group_partitioner_state.partition_id; + } // To avoid accumulating the initial operand multiple times during // all-reduce, we use identity operands for all non-zero partitions. auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeScalarShape(PRED), partition_id_)); + ShapeUtil::MakeScalarShape(PRED), partition_id)); not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero, {})); @@ -1465,7 +1478,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands( scatter->shape(), {select_operand, indices.hlo(), updates.hlo()})); auto all_reduce = - collective_ops_creator_.create_cross_partition_all_reduce( + collective_ops_creator.create_cross_partition_all_reduce( &b_, pscatter, scatter->to_apply(), {}, NewChannel()); all_reduce->set_sharding(HloSharding::Replicate()); SetPartitionedHlo(hlo, [&]() { diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 8b951e3db31..8caaba88260 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -4070,6 +4070,39 @@ ENTRY entry { op::Shape("f32[2,5]"))); } +TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) { const char* const hlo_string = R"( HloModule module @@ -4104,6 +4137,42 @@ ENTRY entry { op::Shape("f32[2,9,8]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %updates = f32[4,4,8] parameter(2), + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::Reshape())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) { const char* const hlo_string = R"( HloModule module @@ -4172,6 +4241,43 @@ ENTRY entry { op::Shape("f32[9,9]"))); } +TEST_F(SpmdPartitioningTest, + ScatterPartitionedOnTrivialSliceDims_PartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + TEST_F(SpmdPartitioningTest, TiledReversePassthrough) { const char* const hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index f6f15481b55..4fc193d9622 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -362,8 +362,8 @@ absl::optional PadFromPartialReplicateShape( // dimensions by dynamic slice. // For example, if partial_sharding is // {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} -// Target tile dims is {2, 2}, the returned compatible sharding will be -// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}. +// Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding +// will be sharding={devices=[2,2]0,2,1,3}. // If patial replicate sharding is not partial replicate or can't reshard to // target_tile_dims by dynamic slice, return absl::nullopt. // If target_sharding is already compatible, returns it. diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index d54eb9e78c3..4015c69e3e2 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -89,16 +89,23 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { // The last block might be smaller than the block size, // so we will need to pad it if (n % block_size != 0) { - // Pad with zeros + // Pad with identity matrix. auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; - config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); last_blocks = Pad(last_blocks, Zero(builder, shape.element_type()), config); + auto eye = + IdentityMatrix(builder, shape.element_type(), padding, padding); + config = MakeNoPaddingConfig(ndims); + config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n % + block_size); + eye = Pad(eye, Zero(builder, shape.element_type()), config); + last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1); + // Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); @@ -121,134 +128,6 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { }); } -XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, - bool conjugate_a, - PrecisionConfig::Precision precision) { - XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - // Input is a batch of square lower triangular square matrices. Its shape is - // (..., size, size). We resize this to (num_blocks, size, size). - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = ShapeUtil::ElementsIn(shape) / - tensorflow::MathUtil::IPow(block_size, 2); - diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); - - // The input must be triangular because we rely on that when doing - // multiplications later on - diag_blocks = Triangle(diag_blocks, /*lower=*/lower); - - // Rescale blocks to be unit triangular, but avoid dividing by - // zero (which can happen if the last block was padded) otherwise it will - // introduce nans which will propagate - auto diags = GetMatrixDiagonal(diag_blocks); - auto ones = FullLike(diags, 1); - diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); - auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); - - // We can now use the fact that for an upper triangular matrix - // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have - // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks - // have been rescaled to be unit triangular, so L22 = L22' = 1. - - // Initialize the output matrix with -1s on the diagonal. We use -1 instead - // of 1 because we cannot do matrix-vector multiplies with variable shapes - // inside of a loop, or do irregularly shaped in-place updates. Hence, - // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the - // entire row i.e. we calculate - // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) - // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. - auto identity = - IdentityMatrix(builder, shape.element_type(), block_size, block_size); - auto neg_identity = -identity; - - // The first or last diagonal element should be set to 1 instead of -1 - // though, since we never update it - auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); - auto output_block = - DynamicUpdateSlice(neg_identity, pos_one, - /*start_indices=*/{start_index, start_index}); - - // Broadcast diag([1, -1, -1, ...]) to every block - XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); - - // Now we construct a loop that performs matrix-vector multiplications - // inverting the blocks one row at a time - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - ShapeUtil::MakeShape(S32, {}), - // The output has the shape of A, with one row updated each iteration. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), - // The input is a loop invariant. - ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init_i = One(builder, S32); - auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); - - // Construct the loop condition function. - std::unique_ptr condb = - builder->CreateSubBuilder("InvertDiagCond"); - { - auto i = GetTupleElement( - Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, ConstantR0(condb.get(), block_size)); - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function. - std::unique_ptr bodyb = - builder->CreateSubBuilder("InvertDiagBody"); - { - auto input_tuple = - Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); - - auto i = GetTupleElement(input_tuple, 0); - auto body_out = GetTupleElement(input_tuple, 1); - auto body_input = GetTupleElement(input_tuple, 2); - - auto zero = ConstantR0(bodyb.get(), 0); - auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto input_row = - DynamicSlice(body_input, {zero, j, zero}, - /*slice_sizes=*/{num_blocks, 1, block_size}); - - // We want -L21 L11^{-1} - DotDimensionNumbers dnums; - dnums.add_lhs_batch_dimensions(0); - dnums.add_rhs_batch_dimensions(0); - dnums.add_lhs_contracting_dimensions(2); - dnums.add_rhs_contracting_dimensions(1); - PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - - body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); - - auto next_i = i + ScalarLike(i, 1); - Tuple(bodyb.get(), {next_i, body_out, body_input}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto invert_while = While(cond, body, init); - auto inv_diag_blocks = GetTupleElement(invert_while, 1); - - // Undo the scaling - inv_diag_blocks = Div(inv_diag_blocks, diags, - /*broadcast_dimensions=*/{0, 1}); - - // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); - }); -} - XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, @@ -357,10 +236,140 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - bool unit_diagonal, int64 block_size, - PrecisionConfig::Precision precision) { +} // namespace + +XlaOp TriangularSolveExpander::InvertDiagonalBlocks( + XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + auto ones = FullLike(diags, 1); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = + ConstantR0(builder, lower_triangular ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); + + // Broadcast diag([1, -1, -1, ...]) to every block + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + ShapeUtil::MakeShape(S32, {}), + // The output has the shape of A, with one row updated each iteration. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, ConstantR0(condb.get(), block_size)); + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = ConstantR0(bodyb.get(), 0); + auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i; + auto input_row = + DynamicSlice(body_input, {zero, j, zero}, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); + + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); + + auto next_i = i + ScalarLike(i, 1); + Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); + }); +} + +XlaOp TriangularSolveExpander::BuildTriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -422,6 +431,11 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return b; } + // Degenerate case: 1x1 matrices. + if (ShapeUtil::GetDimension(a_shape, -1) == 1) { + return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a)); + } + // TODO(phawkins): consider pushing triangle masking into // InvertDiagonalBlocks. if (unit_diagonal) { @@ -440,8 +454,7 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, - conjugate_a, precision); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision); // We now find the solution using GEMMs auto x = @@ -452,8 +465,6 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } -} // namespace - TriangularSolveExpander::TriangularSolveExpander(int64 block_size) : block_size_(block_size) {} diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h index 362e8557229..3f9e58a3246 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.h +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -35,6 +36,14 @@ class TriangularSolveExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular, + PrecisionConfig::Precision precision); + + XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision); + private: // Block size for BuildTriangularSolve const int64 block_size_; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index c80123bcd50..785fdecbfa0 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -37,23 +37,15 @@ namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; -// Tries to remove elements in a while loop's tuple that aren't used within the -// loop. -// -// Specifically, if a loop is tuple-shaped, and there exists some element of -// that tuple that is not used by the loop condition and is not used by the loop -// body except to pass it to the next iteration of the loop, then we can remove -// that element from the loop's tuples. -static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - - // Don't try this transformation if the while loop isn't removable, since if - // it succeeds ultimately we're going to have to replace the old while loop - // with a new one. - if (!while_op->parent()->IsSafelyRemovable(while_op)) { - VLOG(2) << "Can't remove dead parameters from non-removable while op."; - return false; - } +// This is a utility function that removes the given tuple indices from the +// while loop init, body, and condition. The final shape returned is still the +// same as before. +static StatusOr RemoveDeadTupleIndices( + HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices) { + // Build up maps from the old/new to the new/old tuple indices. + std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + absl::c_sort(new_to_old_tuple_idx); HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); @@ -62,107 +54,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if (!while_init->shape().IsTuple()) { - VLOG(2) << "While op's carried value isn't tuple shaped."; - return false; - } - - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple(...) instruction."; - return false; - } - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - // Bail if param0 of while_cond or while_body has users which aren't of type - // get-tuple-element. - for (const HloInstruction* instr : {while_body->parameter_instruction(0), - while_cond->parameter_instruction(0)}) { - for (const HloInstruction* user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToString(print_no_metadata) - << " used by non-GTE instruction " - << user->ToString(print_no_metadata) << " in computation " - << instr->parent()->name(); - return false; - } - } - } - - const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); - if (tuple_size == 0) { - VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " - "empty."; - return false; - } - - absl::flat_hash_set used_tuple_indices; - for (HloComputation* comp : {while_body, while_cond}) { - // The HLO verifier ensures that while_input's shape matches while_init's - // shape, which we verified above is a tuple. - HloInstruction* while_input = comp->parameter_instruction(0); - - for (const HloInstruction* user : while_input->users()) { - // This user doesn't count if it's only used by the while body's root, and - // the root places the tuple element into the same index of the tuple as - // it came from. That just amounts to us carrying the variable through - // the loop. - // - // Careful: HloInstruction::operand_index returns the first index the - // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users().front() == while_body_root && - while_body_root->operand_index(user) == user->tuple_index() && - absl::c_count(while_body_root->operands(), user) == 1) { - continue; - } - - used_tuple_indices.insert(user->tuple_index()); - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If a tuple element is not passed unmodified from the while body's param0 - // through to the while body's root, count that element as "used", since - // removing that element would be observable. - for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.contains(i)) { - continue; - } - - auto* operand = while_body_root->operand(i); - if (operand->opcode() != HloOpcode::kGetTupleElement || - operand->operand(0) != while_body->parameter_instruction(0) || - operand->tuple_index() != i) { - VLOG(2) << "Tuple index " << i - << " is not passed through loop body unmodified."; - used_tuple_indices.insert(i); - - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If we got here, used_tuple_indices.size() < tuple_size, meaning some - // elements of the loop's tuple aren't used by while_body or while_cond. - CHECK_LT(used_tuple_indices.size(), tuple_size); - - VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " - << while_op->ToString(print_no_metadata); - - // Build up maps from the old/new to the new/old tuple indices. - std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), - used_tuple_indices.end()); - absl::c_sort(new_to_old_tuple_idx); - absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; @@ -288,6 +181,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // The tuple simplifier will then simplify this if possible, removing // new_tuple and while_init. std::vector new_tuple_elems; + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { @@ -305,9 +199,293 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); + + return new_while_op; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + absl::flat_hash_set used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. + HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users().front() == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + absl::c_count(while_body_root->operands(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.contains(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " + << while_op->ToString(print_no_metadata); + + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + return true; } +// This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes +// duplicates by replacing them with tuple_index, followed by a call to +// RemoveDeadTupleIndices. +static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( + HloInstruction* while_op, const int64 tuple_index, + absl::flat_hash_set& duplicates) { + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_init = while_op->mutable_operand(0); + + VLOG(2) << "while_init " << while_init->ToString() << " operands " + << while_init->operand_count(); + VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString() + << " operands " << while_body->root_instruction()->operand_count(); + + // Change the loop body and condition such that uses of the duplicates are + // replaced with the original tuple element. + for (HloComputation* comp : {while_body, while_cond}) { + auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index), + comp->parameter_instruction(0), tuple_index)); + + std::vector instrs_to_replace; + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + duplicates.contains(instr->tuple_index()) && + instr->operand(0) == comp->parameter_instruction(0)) { + instrs_to_replace.push_back(instr); + } + } + + for (auto instr : instrs_to_replace) { + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get)); + } + } + + // We know which tuple indices are useful; i.e, those which aren't duplicates. + absl::flat_hash_set used_tuple_indices; + for (int index = 0; index < while_init->shape().tuple_shapes_size(); + ++index) { + if (!duplicates.count(index)) { + used_tuple_indices.insert(index); + } + } + + // Remove the duplicate tuple elements. + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + + return while_op; +} + +// If the while loop init passes the same values to several tuple indices, and +// if the body keeps on passing them through, we can remove the duplicates. +static StatusOr TryRemoveRepeatedWhileTupleIndices( + HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + int index_to_investigate = 0; + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + bool changed = false; + while (index_to_investigate < while_init->shape().tuple_shapes_size()) { + if (!while_init->shape().IsTuple() || + while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto& while_shape = while_init->shape(); + VLOG(2) << "Iterating " << index_to_investigate; + + absl::flat_hash_set duplicates; + auto* pivot_init_elem = while_init->operand(index_to_investigate); + auto* pivot_body_elem = while_body_root->operand(index_to_investigate); + if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement && + pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (pivot_body_elem->tuple_index() != index_to_investigate) { + VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() " + << pivot_body_elem->tuple_index() << " index_to_investigate " + << index_to_investigate; + index_to_investigate++; + continue; + } + } else { + index_to_investigate++; + continue; + } + + // Look from index_to_investigate onwards to see if it is repeated. + for (int64 i = index_to_investigate + 1; + i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (body_elem->opcode() == HloOpcode::kGetTupleElement && + body_elem->operand(0) == while_body->parameter_instruction(0)) { + if (body_elem->tuple_index() != i) { + VLOG(2) << "Mismatch between body_elem->tuple_index() " + << body_elem->tuple_index() << " i " << i; + continue; + } + } else { + continue; + } + + if (pivot_init_elem == init_elem) { + VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem " + << pivot_init_elem->ToString(); + VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem " + << pivot_body_elem->ToString(); + duplicates.insert(i); + } + } + + // If duplicates are found, call the helper to remove them. + if (!duplicates.empty()) { + VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init " + << pivot_init_elem->ToString(); + TF_ASSIGN_OR_RETURN(while_op, + TryRemoveRepeatedWhileTupleIndicesHelper( + while_op, index_to_investigate, duplicates)); + changed = true; + VLOG(2) << "Changed while_op " << while_op->ToString() + << " while_op operand count " << while_op->operand_count(); + // Update the while loop variables so we can continue looking for + // duplicates of a different index. + while_init = while_op->mutable_operand(0); + while_cond = while_op->while_condition(); + while_body = while_op->while_body(); + while_body_root = while_body->root_instruction(); + } + index_to_investigate++; + } + + return changed; +} + // Removes each loop parameter (i.e. member of the while loop tuple) that is a // constant and is the same in the while loop body and the while loop init. static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { @@ -1048,6 +1226,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); changed |= result; + if (result) { // Don't continue simplifying after successfully removing the while loop // -- that would result in use-after-free nastiness. @@ -1067,6 +1246,12 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { // successful, meaning that `while_op` is no longer valid after one of these // transformations returns true. + TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op)); + changed |= result; + if (result) { + continue; + } + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { @@ -1074,6 +1259,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; if (result) { continue; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d715fb3857a..c93cb5dc347 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { .ValueOrDie()); } +TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) { + const string hlo_string = R"( + HloModule SwappingTupleElements + + SwappingTupleElements.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(loop_var), index=0 + get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1 + get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2 + y = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y, + s32[] get-tuple-element.2) + } + + SwappingTupleElements.always_true { + param = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(param), index=0 + get-tuple-element.1 = s32[] get-tuple-element(param), index=1 + ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT + } + + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x) + ROOT while = (s32[], s32[], s32[]) while(tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 73bb3327784..bc48a9c94d1 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -648,7 +648,9 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& target_base_index) { CHECK(ShapeUtil::Compatible( ShapeUtil::GetSubshape(shape(), target_base_index), - ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ShapeUtil::GetSubshape(other.shape(), source_base_index))) + << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs " + << ShapeUtil::GetSubshape(other.shape(), source_base_index); ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( const ShapeIndex& index, T* data) { // Copy the data element only if index is in the diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 17444c042e7..734d2ed443c 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2699,5 +2699,6 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tf32_utils", ], ) diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index e7f5ca5ed8e..9a86852ce5c 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tf32_utils.h" namespace xla { namespace { @@ -181,6 +182,7 @@ class RandomCholeskyTest public ::testing::WithParamInterface {}; XLA_TEST_P(RandomCholeskyTest, Random) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed XlaBuilder builder(TestName()); auto test_params = GetParam(); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 935d8840831..3068a019470 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3000,26 +3000,9 @@ filegroup( visibility = ["//visibility:public"], ) -filegroup( +alias( name = "lmdb_testdata", - testonly = 1, - srcs = [ - # A simple key-value store: - # 0 : 'b' - # 1 : 'b' - # ... - # 9 : 'b' - # Which is then overwritten with: - # 0 : 'a' - # 1 : 'b' - # ... - # 9 : 'j' - "lib/lmdb/testdata/data.mdb", - # LMDB, being a memory-mapped database, uses a different file format on - # big-endian systems. - "lib/lmdb/testdata/data_bigendian.mdb", - ], - visibility = ["//visibility:public"], + actual = "//tensorflow/core/lib/lmdb:lmdb_testdata", ) alias( diff --git a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt index 2184b644b23..dc018aec4aa 100644 --- a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt @@ -1,4 +1,11 @@ op { graph_op_name: "Acos" summary: "Computes acos of x element-wise." + description: <ClearCache(); - } } void CollectiveRemoteAccessLocal::RecvFromPeer( diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h index 12b7dce8ab1..1cd84afafb5 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.h +++ b/tensorflow/core/common_runtime/device_resolver_local.h @@ -42,10 +42,6 @@ class DeviceResolverLocal : public DeviceResolverInterface { Status GetTaskCached(const string& task, std::vector* attributes) override; - void ClearTask(const string& task) override {} - - void ClearCache() override {} - protected: const DeviceMgr* dev_mgr_; }; diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index f76968f05af..3785df7a579 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -114,7 +114,7 @@ class CustomDevice { const string& target_device_name, TensorHandle** result) = 0; - virtual Status Execute(EagerOperation* op, TensorHandle** retvals, + virtual Status Execute(const EagerOperation* op, TensorHandle** retvals, int* num_retvals) = 0; }; diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index ff63c70a98f..d1e1218a370 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -208,12 +208,18 @@ Status EagerOperation::Execute(absl::Span retvals, device = ctx_.HostCPU(); } } + + tensorflow::TensorHandle** retval_array = + reinterpret_cast(retvals.data()); + if (VariantDeviceIsCustom(device)) { + return absl::get(device)->Execute(this, retval_array, + num_retvals); + } + if (device != kVariantDeviceNull) { SetDevice(device); } - return EagerExecute( - this, reinterpret_cast(retvals.data()), - num_retvals); + return EagerExecute(this, retval_array, num_retvals); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 947b67a4dab..6d1ecf64fcc 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -277,11 +277,6 @@ Status EagerOperation::AddInputList( return InferInputListAttrs(inputs.size()); } -Status EagerOperation::SetUseXla(bool enable) { - use_xla_ = enable; - return Status::OK(); -} - Status EagerOperation::Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, @@ -313,7 +308,6 @@ Status EagerOperation::Reset( "registered in the binary running in this process."); } attrs_.Reset(op); - use_xla_ = false; stack_trace_.reset(); is_function_ = is_function; cancellation_manager_ = nullptr; diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 327411e19c9..2e35dd43582 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -120,8 +120,6 @@ class EagerOperation : public ImmediateExecutionOperation { Status InputLength(const char* input_name, int* length) override; Status OutputLength(const char* output_name, int* length) override; - Status SetUseXla(bool enable) override; - void SetStackTrace(AbstractStackTrace stack_trace) override { stack_trace_ = stack_trace; } @@ -227,7 +225,6 @@ class EagerOperation : public ImmediateExecutionOperation { // updated accordingly. VariantDevice device_; - bool use_xla_ = false; absl::optional stack_trace_; bool is_function_; // Conceptually const, but can't be because of Reset bool colocation_exempt_; @@ -257,6 +254,11 @@ inline EagerOperation* OperationFromInterface( return down_cast(operation); } +inline const EagerOperation* OperationFromInterface( + const ImmediateExecutionOperation* operation) { + return down_cast(operation); +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 24582147479..cfb849c78f0 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -1070,11 +1070,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals, [&] { return absl::StrCat("EagerExecute: ", op->Name()); }, profiler::TraceMeLevel::kInfo); - if (VariantDeviceIsCustom(op->Device())) { - return absl::get(op->Device()) - ->Execute(op, retvals, num_retvals); - } - if (!op->Executor().Async()) { // In sync mode, always clear error to maintain the same behavior as before. // TODO(b/141004939): Remove this. diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 148c6c6ce03..619715f1cae 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -185,34 +185,35 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { if (VariantDeviceIsCustom(op.Device())) { *device = op.Device(); return Status::OK(); + } else if (!op.DeviceName().empty()) { + // Don't override explicit placements. + return Status::OK(); } + // Ops are placed on a custom device if there's no other explicit requested + // placement and there is only one custom device in the op inputs. if (!op.Inputs().empty()) { - // We keep track of what we've seen with devices instead of booleans to be - // able to provide a meaningful error message below. - VariantDevice first = op.Inputs()[0]->device(); - VariantDevice different = first; // A different input device, if any. - VariantDevice custom = first; // The first custom device seen, or an - // arbitrary non-custom device otherwise. - for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) { - VariantDevice device = op.Inputs()[i]->device(); - if (device != first) { - different = device; - } - if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) { - custom = device; - } - if (different != first && VariantDeviceIsCustom(custom)) { - return errors::InvalidArgument(absl::StrCat( - "If an operation has one of its inputs in a custom device, then " - "all inputs should be on that same device. Operation ", - op.Name(), " has one input in custom device ", - VariantDeviceName(custom), - " and at least one input in a different device ", - VariantDeviceName(custom == first ? different : first))); + CustomDevice* first = nullptr; + for (const TensorHandle* input : op.Inputs()) { + if (VariantDeviceIsCustom(input->device())) { + CustomDevice* current = absl::get(input->device()); + if (first == nullptr) { + first = current; + } else if (first != current) { + return errors::InvalidArgument(absl::StrCat( + "If an operation has one of its inputs in a custom device, then " + "all inputs should be on that same custom device or another " + "physical device. Operation ", + op.Name(), + " has one input in custom " + "device ", + VariantDeviceName(first), + " and at least one input in a different custom device ", + VariantDeviceName(current))); + } } } - if (different == first && VariantDeviceIsCustom(custom)) { + if (first != nullptr) { *device = first; return Status::OK(); } diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 35971e39ea1..41a38152c69 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -49,6 +49,78 @@ tf_proto_library( ], ) +cc_library( + name = "credentials_factory", + srcs = ["credentials_factory.cc"], + hdrs = ["credentials_factory.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), + ], +) + +tf_cc_test( + name = "credentials_factory_test", + srcs = ["credentials_factory_test.cc"], + deps = [ + ":credentials_factory", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "data_service", + srcs = ["data_service.cc"], + hdrs = [ + "data_service.h", + ], + deps = [ + ":credentials_factory", + ":dispatcher_cc_grpc_proto", + ":dispatcher_proto_cc", + ":grpc_util", + ":worker_cc_grpc_proto", + ":worker_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + tf_grpc_cc_dependency(), + ], +) + +tf_cc_test( + name = "data_service_test", + srcs = ["data_service_test.cc"], + tags = ["no_windows"], + deps = [ + ":data_service", + ":dispatcher_cc_grpc_proto", + ":dispatcher_proto_cc", + ":grpc_dispatcher_impl", + ":grpc_util", + ":grpc_worker_impl", + ":local_credentials_factory", + ":server_lib", + ":test_cluster", + ":test_util", + ":worker_cc_grpc_proto", + ":worker_proto_cc", + "@com_google_absl//absl/strings", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/data:compression_utils", + "//tensorflow/core/kernels/data:dataset_test_base", + tf_grpc_cc_dependency(), + ] + tf_protos_profiler_service(), +) + cc_library( name = "dataset_store", srcs = ["dataset_store.cc"], @@ -78,6 +150,14 @@ tf_cc_test( ], ) +cc_grpc_library( + name = "dispatcher_cc_grpc_proto", + srcs = [":dispatcher_proto"], + generate_mocks = True, + grpc_only = True, + deps = [":dispatcher_proto_cc"], +) + cc_library( name = "dispatcher_impl", srcs = ["dispatcher_impl.cc"], @@ -141,32 +221,14 @@ tf_cc_test( ) cc_library( - name = "worker_impl", - srcs = ["worker_impl.cc"], - hdrs = [ - "worker_impl.h", - ], + name = "grpc_dispatcher_impl", + srcs = ["grpc_dispatcher_impl.cc"], + hdrs = ["grpc_dispatcher_impl.h"], deps = [ - ":common_proto_cc", - ":credentials_factory", - ":data_service", ":dispatcher_cc_grpc_proto", - ":dispatcher_proto_cc", - ":grpc_util", - ":utils", - ":worker_proto_cc", - "//tensorflow/c:c_api_internal", - "//tensorflow/c:tf_status_helper", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + ":dispatcher_impl", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/data:dataset_proto_cc", - "//tensorflow/core/data:standalone", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", tf_grpc_cc_dependency(), ], ) @@ -196,6 +258,19 @@ tf_cc_test( ], ) +cc_library( + name = "grpc_worker_impl", + srcs = ["grpc_worker_impl.cc"], + hdrs = ["grpc_worker_impl.h"], + deps = [ + ":worker_cc_grpc_proto", + ":worker_impl", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + tf_grpc_cc_dependency(), + ], +) + cc_library( name = "journal", srcs = ["journal.cc"], @@ -209,6 +284,15 @@ cc_library( ], ) +tf_proto_library( + name = "journal_proto", + srcs = ["journal.proto"], + cc_api_version = 2, + protodeps = [ + ":common_proto", + ], +) + tf_cc_test( name = "journal_test", srcs = ["journal_test.cc"], @@ -224,49 +308,50 @@ tf_cc_test( ], ) -tf_proto_library( - name = "journal_proto", - srcs = ["journal.proto"], - cc_api_version = 2, - protodeps = [ - ":common_proto", - ], -) - -cc_library( - name = "credentials_factory", - srcs = ["credentials_factory.cc"], - hdrs = ["credentials_factory.h"], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - tf_grpc_cc_dependency(), - ], -) - -tf_cc_test( - name = "credentials_factory_test", - srcs = ["credentials_factory_test.cc"], - deps = [ - ":credentials_factory", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - # Link this target to enable LOCAL credentials for the dataset service. cc_library( name = "local_credentials_factory", srcs = ["local_credentials_factory.cc"], deps = [ ":credentials_factory", + "@com_google_absl//absl/memory", tf_grpc_cc_dependency(), ], alwayslink = 1, ) +cc_library( + name = "server_lib", + srcs = ["server_lib.cc"], + hdrs = ["server_lib.h"], + linkstatic = True, + visibility = [ + "//visibility:public", + ], + deps = [ + ":credentials_factory", + ":grpc_dispatcher_impl", + ":grpc_util", + ":grpc_worker_impl", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/profiler/rpc:profiler_service_impl", + tf_grpc_cc_dependency(), + ], + alwayslink = 1, +) + +# This needs to be cc_header_only_library - tf_pybind_cc_library_wrapper +# does not pull in the server_lib.h header. +cc_header_only_library( + name = "server_lib_headers_lib", + features = ["-parse_headers"], + deps = [ + ":server_lib", + ], +) + cc_library( name = "test_cluster", testonly = True, @@ -310,64 +395,6 @@ tf_cc_test( ], ) -cc_library( - name = "grpc_dispatcher_impl", - srcs = ["grpc_dispatcher_impl.cc"], - hdrs = ["grpc_dispatcher_impl.h"], - deps = [ - ":dispatcher_cc_grpc_proto", - ":dispatcher_impl", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", - tf_grpc_cc_dependency(), - ], -) - -cc_library( - name = "grpc_worker_impl", - srcs = ["grpc_worker_impl.cc"], - hdrs = ["grpc_worker_impl.h"], - deps = [ - ":worker_cc_grpc_proto", - ":worker_impl", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", - tf_grpc_cc_dependency(), - ], -) - -# This needs to be cc_header_only_library - tf_pybind_cc_library_wrapper -# does not pull in the server_lib.h header. -cc_header_only_library( - name = "server_lib_headers_lib", - features = ["-parse_headers"], - deps = [ - ":server_lib", - ], -) - -cc_library( - name = "server_lib", - srcs = ["server_lib.cc"], - hdrs = ["server_lib.h"], - linkstatic = True, - visibility = [ - "//visibility:public", - ], - deps = [ - ":credentials_factory", - ":grpc_dispatcher_impl", - ":grpc_util", - ":grpc_worker_impl", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", - "//tensorflow/core/profiler/rpc:profiler_service_impl", - tf_grpc_cc_dependency(), - ], - alwayslink = 1, -) - cc_library( name = "utils", srcs = ["utils.cc"], @@ -394,62 +421,6 @@ tf_cc_test( ], ) -cc_library( - name = "data_service", - srcs = ["data_service.cc"], - hdrs = [ - "data_service.h", - ], - deps = [ - ":credentials_factory", - ":dispatcher_cc_grpc_proto", - ":dispatcher_proto_cc", - ":grpc_util", - ":worker_cc_grpc_proto", - ":worker_proto_cc", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - tf_grpc_cc_dependency(), - ], -) - -tf_cc_test( - name = "data_service_test", - srcs = ["data_service_test.cc"], - tags = ["no_windows"], - deps = [ - ":data_service", - ":dispatcher_cc_grpc_proto", - ":dispatcher_proto_cc", - ":grpc_dispatcher_impl", - ":grpc_util", - ":grpc_worker_impl", - ":local_credentials_factory", - ":server_lib", - ":test_cluster", - ":test_util", - ":worker_cc_grpc_proto", - ":worker_proto_cc", - "@com_google_absl//absl/strings", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/data:compression_utils", - "//tensorflow/core/kernels/data:dataset_test_base", - tf_grpc_cc_dependency(), - ] + tf_protos_profiler_service(), -) - -cc_grpc_library( - name = "dispatcher_cc_grpc_proto", - srcs = [":dispatcher_proto"], - generate_mocks = True, - grpc_only = True, - deps = [":dispatcher_proto_cc"], -) - cc_grpc_library( name = "worker_cc_grpc_proto", srcs = [":worker_proto"], @@ -457,3 +428,34 @@ cc_grpc_library( grpc_only = True, deps = [":worker_proto_cc"], ) + +cc_library( + name = "worker_impl", + srcs = ["worker_impl.cc"], + hdrs = [ + "worker_impl.h", + ], + deps = [ + ":common_proto_cc", + ":credentials_factory", + ":data_service", + ":dispatcher_cc_grpc_proto", + ":dispatcher_proto_cc", + ":grpc_util", + ":utils", + ":worker_proto_cc", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data:dataset_proto_cc", + "//tensorflow/core/data:standalone", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + tf_grpc_cc_dependency(), + ], +) diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc index 43b56d54d2e..e7a5177e7db 100644 --- a/tensorflow/core/data/service/credentials_factory.cc +++ b/tensorflow/core/data/service/credentials_factory.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/data/service/credentials_factory.h" +#include "absl/memory/memory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/mutex.h" @@ -35,9 +36,11 @@ CredentialsFactories& credentials_factories() { } } // namespace -void CredentialsFactory::Register(CredentialsFactory* factory) { +void CredentialsFactory::Register(std::unique_ptr factory) { mutex_lock l(*get_lock()); - if (!credentials_factories().insert({factory->Protocol(), factory}).second) { + if (!credentials_factories() + .insert({factory->Protocol(), factory.release()}) + .second) { LOG(ERROR) << "Two credentials factories are being registered with protocol " << factory->Protocol() << ". Which one gets used is undefined."; @@ -45,11 +48,11 @@ void CredentialsFactory::Register(CredentialsFactory* factory) { } Status CredentialsFactory::Get(absl::string_view protocol, - CredentialsFactory** out) { + CredentialsFactory*& out) { mutex_lock l(*get_lock()); auto it = credentials_factories().find(std::string(protocol)); if (it != credentials_factories().end()) { - *out = it->second; + out = it->second; return Status::OK(); } @@ -66,18 +69,18 @@ Status CredentialsFactory::Get(absl::string_view protocol, Status CredentialsFactory::CreateServerCredentials( absl::string_view protocol, - std::shared_ptr<::grpc::ServerCredentials>* out) { + std::shared_ptr<::grpc::ServerCredentials>& out) { CredentialsFactory* factory; - TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); + TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, factory)); TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out)); return Status::OK(); } Status CredentialsFactory::CreateClientCredentials( absl::string_view protocol, - std::shared_ptr<::grpc::ChannelCredentials>* out) { + std::shared_ptr<::grpc::ChannelCredentials>& out) { CredentialsFactory* factory; - TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); + TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, factory)); TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out)); return Status::OK(); } @@ -87,14 +90,14 @@ class InsecureCredentialsFactory : public CredentialsFactory { std::string Protocol() override { return "grpc"; } Status CreateServerCredentials( - std::shared_ptr<::grpc::ServerCredentials>* out) override { - *out = ::grpc::InsecureServerCredentials(); + std::shared_ptr<::grpc::ServerCredentials>& out) override { + out = ::grpc::InsecureServerCredentials(); return Status::OK(); } Status CreateClientCredentials( - std::shared_ptr<::grpc::ChannelCredentials>* out) override { - *out = ::grpc::InsecureChannelCredentials(); + std::shared_ptr<::grpc::ChannelCredentials>& out) override { + out = ::grpc::InsecureChannelCredentials(); return Status::OK(); } }; @@ -102,8 +105,8 @@ class InsecureCredentialsFactory : public CredentialsFactory { class InsecureCredentialsRegistrar { public: InsecureCredentialsRegistrar() { - auto factory = new InsecureCredentialsFactory(); - CredentialsFactory::Register(factory); + CredentialsFactory::Register( + absl::make_unique()); } }; static InsecureCredentialsRegistrar registrar; diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h index 2407f64ee7f..754ed5473ac 100644 --- a/tensorflow/core/data/service/credentials_factory.h +++ b/tensorflow/core/data/service/credentials_factory.h @@ -34,33 +34,33 @@ class CredentialsFactory { // look up with `GetCredentials` to find the registered credentials factory. virtual std::string Protocol() = 0; - // Stores server credentials to `*out`. + // Stores server credentials to `out`. virtual Status CreateServerCredentials( - std::shared_ptr<::grpc::ServerCredentials>* out) = 0; + std::shared_ptr<::grpc::ServerCredentials>& out) = 0; - // Stores client credentials to `*out`. + // Stores client credentials to `out`. virtual Status CreateClientCredentials( - std::shared_ptr<::grpc::ChannelCredentials>* out) = 0; + std::shared_ptr<::grpc::ChannelCredentials>& out) = 0; // Registers a credentials factory. - static void Register(CredentialsFactory* factory); + static void Register(std::unique_ptr factory); // Creates server credentials using the credentials factory registered as - // `protocol`, and stores them to `*out`. + // `protocol`, and stores them to `out`. static Status CreateServerCredentials( absl::string_view protocol, - std::shared_ptr<::grpc::ServerCredentials>* out); + std::shared_ptr<::grpc::ServerCredentials>& out); // Creates client credentials using the credentials factory registered as - // `protocol`, and stores them to `*out`. + // `protocol`, and stores them to `out`. static Status CreateClientCredentials( absl::string_view protocol, - std::shared_ptr<::grpc::ChannelCredentials>* out); + std::shared_ptr<::grpc::ChannelCredentials>& out); private: - // Gets the credentials factory registered via `Register` for the specified - // protocol, and stores it to `*out`. - static Status Get(const absl::string_view protocol, CredentialsFactory** out); + // Borrows a pointer to the credentials factory registered via `Register` + // for the specified protocol, and stores it to `out`. + static Status Get(const absl::string_view protocol, CredentialsFactory*& out); }; } // namespace data diff --git a/tensorflow/core/data/service/credentials_factory_test.cc b/tensorflow/core/data/service/credentials_factory_test.cc index 507c553963a..fbaed581f77 100644 --- a/tensorflow/core/data/service/credentials_factory_test.cc +++ b/tensorflow/core/data/service/credentials_factory_test.cc @@ -32,43 +32,44 @@ class TestCredentialsFactory : public CredentialsFactory { std::string Protocol() override { return "test"; } Status CreateServerCredentials( - std::shared_ptr* out) override { + std::shared_ptr& out) override { return errors::Internal(kFailedToCreateServerCredentials); } Status CreateClientCredentials( - std::shared_ptr* out) override { + std::shared_ptr& out) override { return errors::Internal(kFailedToCreateClientCredentials); } }; } // namespace TEST(CredentialsFactory, Register) { - TestCredentialsFactory test_factory; - CredentialsFactory::Register(&test_factory); + auto test_factory = absl::make_unique(); + std::string protocol = test_factory->Protocol(); + CredentialsFactory::Register(std::move(test_factory)); std::shared_ptr server_credentials; ASSERT_EQ(errors::Internal(kFailedToCreateServerCredentials), - CredentialsFactory::CreateServerCredentials(test_factory.Protocol(), - &server_credentials)); + CredentialsFactory::CreateServerCredentials(protocol, + server_credentials)); std::shared_ptr client_credentials; ASSERT_EQ(errors::Internal(kFailedToCreateClientCredentials), - CredentialsFactory::CreateClientCredentials(test_factory.Protocol(), - &client_credentials)); + CredentialsFactory::CreateClientCredentials(protocol, + client_credentials)); } TEST(CredentialsFactory, DefaultGrpcProtocol) { std::shared_ptr server_credentials; TF_ASSERT_OK( - CredentialsFactory::CreateServerCredentials("grpc", &server_credentials)); + CredentialsFactory::CreateServerCredentials("grpc", server_credentials)); std::shared_ptr client_credentials; TF_ASSERT_OK( - CredentialsFactory::CreateClientCredentials("grpc", &client_credentials)); + CredentialsFactory::CreateClientCredentials("grpc", client_credentials)); } TEST(CredentialsFactory, MissingServerProtocol) { std::shared_ptr server_credentials; Status s = CredentialsFactory::CreateServerCredentials("unknown_protocol", - &server_credentials); + server_credentials); ASSERT_EQ(error::Code::NOT_FOUND, s.code()); ASSERT_TRUE( absl::StrContains(s.ToString(), @@ -79,7 +80,7 @@ TEST(CredentialsFactory, MissingServerProtocol) { TEST(CredentialsFactory, MissingClientProtocol) { std::shared_ptr client_credentials; Status s = CredentialsFactory::CreateClientCredentials("unknown_protocol", - &client_credentials); + client_credentials); ASSERT_EQ(error::Code::NOT_FOUND, s.code()); ASSERT_TRUE( absl::StrContains(s.ToString(), diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc index 0f25805b653..c64daa69c7d 100644 --- a/tensorflow/core/data/service/data_service.cc +++ b/tensorflow/core/data/service/data_service.cc @@ -31,11 +31,11 @@ constexpr const char kParallelEpochs[] = "parallel_epochs"; constexpr const char kOneEpoch[] = "one_epoch"; } // namespace -Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) { +Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) { if (s == kParallelEpochs) { - *mode = ProcessingMode::PARALLEL_EPOCHS; + mode = ProcessingMode::PARALLEL_EPOCHS; } else if (s == kOneEpoch) { - *mode = ProcessingMode::ONE_EPOCH; + mode = ProcessingMode::ONE_EPOCH; } else { return errors::InvalidArgument("Unrecognized processing mode: ", s); } @@ -105,7 +105,7 @@ Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id, } Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset, - int64* dataset_id) { + int64& dataset_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrRegisterDatasetRequest req; *req.mutable_dataset()->mutable_graph() = dataset; @@ -115,13 +115,13 @@ Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset, if (!status.ok()) { return grpc_util::WrapError("Failed to register dataset", status); } - *dataset_id = resp.dataset_id(); + dataset_id = resp.dataset_id(); return Status::OK(); } Status DataServiceDispatcherClient::CreateJob(int64 dataset_id, ProcessingMode processing_mode, - int64* job_client_id) { + int64& job_client_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); CreateJobRequest req; req.set_dataset_id(dataset_id); @@ -134,13 +134,13 @@ Status DataServiceDispatcherClient::CreateJob(int64 dataset_id, absl::StrCat("Failed to create job for dataset with id ", dataset_id), status); } - *job_client_id = resp.job_client_id(); + job_client_id = resp.job_client_id(); return Status::OK(); } Status DataServiceDispatcherClient::GetOrCreateJob( int64 dataset_id, ProcessingMode processing_mode, - const std::string& job_name, int job_name_index, int64* job_client_id) { + const std::string& job_name, int job_name_index, int64& job_client_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrCreateJobRequest req; req.set_dataset_id(dataset_id); @@ -156,7 +156,7 @@ Status DataServiceDispatcherClient::GetOrCreateJob( dataset_id), status); } - *job_client_id = resp.job_client_id(); + job_client_id = resp.job_client_id(); return Status::OK(); } @@ -176,8 +176,8 @@ Status DataServiceDispatcherClient::ReleaseJobClient(int64 job_client_id) { } Status DataServiceDispatcherClient::GetTasks(int64 job_client_id, - std::vector* tasks, - bool* job_finished) { + std::vector& tasks, + bool& job_finished) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetTasksRequest req; req.set_job_client_id(job_client_id); @@ -187,16 +187,16 @@ Status DataServiceDispatcherClient::GetTasks(int64 job_client_id, if (!s.ok()) { return grpc_util::WrapError("Failed to get tasks", s); } - tasks->clear(); + tasks.clear(); for (auto& task : resp.task_info()) { - tasks->push_back(task); + tasks.push_back(task); } - *job_finished = resp.job_finished(); + job_finished = resp.job_finished(); return Status::OK(); } Status DataServiceDispatcherClient::GetWorkers( - std::vector* workers) { + std::vector& workers) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetWorkersRequest req; GetWorkersResponse resp; @@ -205,9 +205,9 @@ Status DataServiceDispatcherClient::GetWorkers( if (!s.ok()) { return grpc_util::WrapError("Failed to get workers", s); } - workers->clear(); + workers.clear(); for (auto& worker : resp.workers()) { - workers->push_back(worker); + workers.push_back(worker); } return Status::OK(); } @@ -219,15 +219,15 @@ Status DataServiceDispatcherClient::EnsureInitialized() { } std::shared_ptr credentials; TF_RETURN_IF_ERROR( - CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); + CredentialsFactory::CreateClientCredentials(protocol_, credentials)); auto channel = grpc::CreateChannel(address_, credentials); stub_ = DispatcherService::NewStub(channel); return Status::OK(); } Status DataServiceWorkerClient::GetElement(int64 task_id, - CompressedElement* element, - bool* end_of_sequence) { + CompressedElement& element, + bool& end_of_sequence) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetElementRequest req; req.set_task_id(task_id); @@ -237,9 +237,9 @@ Status DataServiceWorkerClient::GetElement(int64 task_id, if (!s.ok()) { return grpc_util::WrapError("Failed to get element", s); } - *end_of_sequence = resp.end_of_sequence(); - if (!*end_of_sequence) { - *element = std::move(*resp.mutable_compressed_element()); + end_of_sequence = resp.end_of_sequence(); + if (!end_of_sequence) { + element = std::move(*resp.mutable_compressed_element()); } return Status::OK(); } @@ -251,7 +251,7 @@ Status DataServiceWorkerClient::EnsureInitialized() { } std::shared_ptr credentials; TF_RETURN_IF_ERROR( - CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); + CredentialsFactory::CreateClientCredentials(protocol_, credentials)); grpc::ChannelArguments args; args.SetMaxReceiveMessageSize(-1); auto channel = grpc::CreateCustomChannel(address_, credentials, args); @@ -261,20 +261,20 @@ Status DataServiceWorkerClient::EnsureInitialized() { Status CreateDataServiceDispatcherClient( const std::string& address, const std::string& protocol, - std::unique_ptr* out) { + std::unique_ptr& out) { auto client = absl::make_unique(address, protocol); TF_RETURN_IF_ERROR(client->Initialize()); - *out = std::move(client); + out = std::move(client); return Status::OK(); } Status CreateDataServiceWorkerClient( const std::string& address, const std::string& protocol, - std::unique_ptr* out) { + std::unique_ptr& out) { auto client = absl::make_unique(address, protocol); TF_RETURN_IF_ERROR(client->Initialize()); - *out = std::move(client); + out = std::move(client); return Status::OK(); } } // namespace data diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h index 621e76da749..c5eb6a37269 100644 --- a/tensorflow/core/data/service/data_service.h +++ b/tensorflow/core/data/service/data_service.h @@ -34,8 +34,8 @@ enum class ProcessingMode : int64 { }; // Parses a string representing a processing mode and stores the result in -// *mode. Returns an InvalidArgument status if the string is not recognized. -Status ParseProcessingMode(const std::string& s, ProcessingMode* mode); +// `mode`. Returns an InvalidArgument status if the string is not recognized. +Status ParseProcessingMode(const std::string& s, ProcessingMode& mode); // Converts a processing mode to its corresponding string. std::string ProcessingModeToString(ProcessingMode mode); @@ -87,34 +87,34 @@ class DataServiceDispatcherClient : public DataServiceClientBase { Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def); // Registers a dataset with the tf.data service, and stores the generated - // dataset id in `*dataset_id`. - Status RegisterDataset(GraphDef dataset, int64* dataset_id); + // dataset id in `dataset_id`. + Status RegisterDataset(GraphDef dataset, int64& dataset_id); // Creates a new tf.data service job for the specified dataset. The id for the - // created job will be stored in `*job_client_id`. + // created job will be stored in `job_client_id`. Status CreateJob(int64 dataset_id, ProcessingMode processing_mode, - int64* job_client_id); + int64& job_client_id); // Gets the job id for the job represented by the tuple - // (job_name, job_name_index), and stores the id in *job_client_id. If the + // (job_name, job_name_index), and stores the id in `job_client_id`. If the // job doesn't exist yet, it will be created. Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode, const std::string& job_name, int job_name_index, - int64* job_client_id); + int64& job_client_id); // Releases a job client id, indicating that the id will no longer be used to // read from the job. Status ReleaseJobClient(int64 job_client_id); // Queries the dispatcher for the tasks associated with the specified job. - // The tasks will be stored in *tasks, and whether the job is finished will - // be stored in `*job_finished`. - Status GetTasks(int64 job_client_id, std::vector* tasks, - bool* job_finished); + // The tasks will be stored in `tasks`, and whether the job is finished will + // be stored in `job_finished`. + Status GetTasks(int64 job_client_id, std::vector& tasks, + bool& job_finished); // Queries the dispatcher for its registered workers. The worker info will be - // stored in `*workers`. - Status GetWorkers(std::vector* workers); + // stored in `workers`. + Status GetWorkers(std::vector& workers); protected: Status EnsureInitialized() override; @@ -134,10 +134,10 @@ class DataServiceWorkerClient : public DataServiceClientBase { : DataServiceClientBase(address, protocol) {} // Fetches the next element for the specified task_id. The element's - // compressed tensors will be stored in *element. If no element is available, - // `*end_of_sequence` will be `true`, and `element` will be left unchanged. - Status GetElement(int64 task_id, CompressedElement* element, - bool* end_of_sequence); + // compressed tensors will be stored in `element`. If no element is available, + // `end_of_sequence` will be `true`, and `element` will be left unchanged. + Status GetElement(int64 task_id, CompressedElement& element, + bool& end_of_sequence); protected: Status EnsureInitialized() override; @@ -152,12 +152,12 @@ class DataServiceWorkerClient : public DataServiceClientBase { // Creates and initializes a new tf.data service dispatcher client. Status CreateDataServiceDispatcherClient( const std::string& address, const std::string& protocol, - std::unique_ptr* out); + std::unique_ptr& out); // Creates and initializes a new tf.data service worker client. Status CreateDataServiceWorkerClient( const std::string& address, const std::string& protocol, - std::unique_ptr* out); + std::unique_ptr& out); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc index 607570054b4..7b7e240b687 100644 --- a/tensorflow/core/data/service/data_service_test.cc +++ b/tensorflow/core/data/service/data_service_test.cc @@ -41,19 +41,19 @@ constexpr const char kProtocol[] = "grpc+local"; TEST(DataService, ParseParallelEpochsProcessingMode) { ProcessingMode mode; - TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", &mode)); + TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", mode)); EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS); } TEST(DataService, ParseOneEpochProcessingMode) { ProcessingMode mode; - TF_ASSERT_OK(ParseProcessingMode("one_epoch", &mode)); + TF_ASSERT_OK(ParseProcessingMode("one_epoch", mode)); EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH); } TEST(DataService, ParseInvalidProcessingMode) { ProcessingMode mode; - Status s = ParseProcessingMode("invalid", &mode); + Status s = ParseProcessingMode("invalid", mode); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); } @@ -69,7 +69,7 @@ TEST(DataService, GetWorkers) { DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(), kProtocol); std::vector workers; - TF_EXPECT_OK(dispatcher.GetWorkers(&workers)); + TF_EXPECT_OK(dispatcher.GetWorkers(workers)); EXPECT_EQ(1, workers.size()); } diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index de5f63a01a0..dcd5cb5d80b 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -71,14 +71,14 @@ std::string DatasetKey(int64 id, uint64 fingerprint) { } Status CreateWorkerStub(const std::string& address, const std::string& protocol, - std::unique_ptr* stub) { + std::unique_ptr& stub) { ::grpc::ChannelArguments args; args.SetMaxReceiveMessageSize(-1); std::shared_ptr<::grpc::ChannelCredentials> credentials; TF_RETURN_IF_ERROR( - CredentialsFactory::CreateClientCredentials(protocol, &credentials)); + CredentialsFactory::CreateClientCredentials(protocol, credentials)); auto channel = ::grpc::CreateCustomChannel(address, credentials, args); - *stub = WorkerService::NewStub(channel); + stub = WorkerService::NewStub(channel); return Status::OK(); } } // namespace @@ -117,7 +117,7 @@ Status DataServiceDispatcherImpl::Start() { Update update; bool end_of_journal = false; FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir())); - Status s = reader.Read(&update, &end_of_journal); + Status s = reader.Read(update, end_of_journal); if (errors::IsNotFound(s)) { LOG(INFO) << "No journal found. Starting dispatcher from new state."; } else if (!s.ok()) { @@ -125,7 +125,7 @@ Status DataServiceDispatcherImpl::Start() { } else { while (!end_of_journal) { TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update)); - TF_RETURN_IF_ERROR(reader.Read(&update, &end_of_journal)); + TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal)); } } // Initialize the journal writer in `Start` so that we fail fast in case it @@ -168,11 +168,11 @@ Status DataServiceDispatcherImpl::RegisterWorker( if (it != tasks_by_job.end()) { task = it->second; } else { - TF_RETURN_IF_ERROR(CreateTask(job, worker_address, &task)); + TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task)); } TaskDef* task_def = response->add_tasks(); std::shared_ptr dataset; - TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, &dataset)); + TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset)); std::string dataset_key = DatasetKey(dataset->dataset_id, dataset->fingerprint); if (config_.work_dir().empty()) { @@ -199,7 +199,7 @@ Status DataServiceDispatcherImpl::WorkerUpdate( for (auto& update : request->updates()) { int64 task_id = update.task_id(); std::shared_ptr task; - TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, &task)); + TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task)); if (update.completed()) { if (task->finished) { VLOG(1) << "Received completion update for already-finished task " @@ -220,7 +220,7 @@ Status DataServiceDispatcherImpl::GetDatasetDef( const GetDatasetDefRequest* request, GetDatasetDefResponse* response) { mutex_lock l(mu_); std::shared_ptr dataset; - TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), &dataset)); + TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset)); std::string key = DatasetKey(dataset->dataset_id, dataset->fingerprint); std::shared_ptr dataset_def; TF_RETURN_IF_ERROR(dataset_store_->Get(key, dataset_def)); @@ -242,7 +242,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( VLOG(4) << "Registering dataset graph: " << graph.DebugString(); #endif std::shared_ptr dataset; - Status s = state_.DatasetFromFingerprint(fingerprint, &dataset); + Status s = state_.DatasetFromFingerprint(fingerprint, dataset); if (s.ok()) { int64 id = dataset->dataset_id; VLOG(3) << "Received duplicate RegisterDataset request with fingerprint " @@ -254,7 +254,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( } int64 id; - TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), &id)); + TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id)); response->set_dataset_id(id); VLOG(3) << "Registered new dataset with id " << id; return Status::OK(); @@ -262,15 +262,15 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint, const DatasetDef& dataset, - int64* dataset_id) + int64& dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *dataset_id = state_.NextAvailableDatasetId(); + dataset_id = state_.NextAvailableDatasetId(); Update update; RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset(); - register_dataset->set_dataset_id(*dataset_id); + register_dataset->set_dataset_id(dataset_id); register_dataset->set_fingerprint(fingerprint); TF_RETURN_IF_ERROR( - dataset_store_->Put(DatasetKey(*dataset_id, fingerprint), dataset)); + dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset)); return Apply(update); } @@ -284,11 +284,11 @@ Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request, { mutex_lock l(mu_); TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode, - absl::optional(), &job)); + absl::optional(), job)); int64 job_client_id; TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id)); response->set_job_client_id(job_client_id); - TF_RETURN_IF_ERROR(CreateTasksForJob(job, &tasks)); + TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks)); } TF_RETURN_IF_ERROR(AssignTasks(tasks)); @@ -309,7 +309,7 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( std::vector> tasks; { mutex_lock l(mu_); - Status s = state_.NamedJobByKey(key, &job); + Status s = state_.NamedJobByKey(key, job); if (s.ok()) { TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode, request->dataset_id())); @@ -323,11 +323,11 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( return s; } TF_RETURN_IF_ERROR( - CreateJob(request->dataset_id(), requested_processing_mode, key, &job)); + CreateJob(request->dataset_id(), requested_processing_mode, key, job)); int64 job_client_id; TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id)); response->set_job_client_id(job_client_id); - TF_RETURN_IF_ERROR(CreateTasksForJob(job, &tasks)); + TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks)); } TF_RETURN_IF_ERROR(AssignTasks(tasks)); VLOG(3) << "Created job " << job->job_id << " for dataset " @@ -376,7 +376,7 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob( Status DataServiceDispatcherImpl::CreateJob( int64 dataset_id, ProcessingMode processing_mode, - absl::optional named_job_key, std::shared_ptr* job) + absl::optional named_job_key, std::shared_ptr& job) EXCLUSIVE_LOCKS_REQUIRED(mu_) { switch (processing_mode) { case ProcessingMode::PARALLEL_EPOCHS: @@ -421,22 +421,22 @@ Status DataServiceDispatcherImpl::AcquireJobClientId( Status DataServiceDispatcherImpl::CreateTasksForJob( std::shared_ptr job, - std::vector>* tasks) + std::vector>& tasks) EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector> workers = state_.ListWorkers(); - tasks->clear(); - tasks->reserve(workers.size()); + tasks.clear(); + tasks.reserve(workers.size()); for (const auto& worker : workers) { std::shared_ptr task; - TF_RETURN_IF_ERROR(CreateTask(job, worker->address, &task)); - tasks->push_back(task); + TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task)); + tasks.push_back(task); } return Status::OK(); } Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr job, const std::string& worker_address, - std::shared_ptr* task) + std::shared_ptr& task) EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64 task_id = state_.NextAvailableTaskId(); Update update; @@ -459,19 +459,19 @@ Status DataServiceDispatcherImpl::AssignTasks( } Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( - const std::string& worker_address, WorkerService::Stub** out_stub) + const std::string& worker_address, WorkerService::Stub*& out_stub) LOCKS_EXCLUDED(mu_) { { mutex_lock l(mu_); auto it = worker_stubs_.find(worker_address); if (it != worker_stubs_.end()) { - *out_stub = it->second.get(); + out_stub = it->second.get(); return Status::OK(); } } std::unique_ptr stub; TF_RETURN_IF_ERROR( - CreateWorkerStub(worker_address, config_.protocol(), &stub)); + CreateWorkerStub(worker_address, config_.protocol(), stub)); { mutex_lock l(mu_); // A concurrent call could have already created the stub. @@ -479,7 +479,7 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( if (worker == nullptr) { worker = std::move(stub); } - *out_stub = worker.get(); + out_stub = worker.get(); } return Status::OK(); } @@ -495,7 +495,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) { mutex_lock l(mu_); std::shared_ptr dataset; - TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset)); + TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset)); std::string dataset_key = DatasetKey(dataset->dataset_id, dataset->fingerprint); if (config_.work_dir().empty()) { @@ -511,7 +511,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) task_def->set_task_id(task->task_id); ProcessTaskResponse resp; WorkerService::Stub* stub; - TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, &stub)); + TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub)); grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp); if (!s.ok()) { return grpc_util::WrapError( @@ -530,7 +530,7 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request, std::shared_ptr job; TF_RETURN_IF_ERROR(state_.JobForJobClientId(request->job_client_id(), job)); std::vector> tasks; - TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, &tasks)); + TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks)); for (const auto& task : tasks) { TaskInfo* task_info = response->mutable_task_info()->Add(); task_info->set_worker_address(task->worker_address); diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index 34cdc678183..2cf341a812e 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -77,19 +77,20 @@ class DataServiceDispatcherImpl { private: // Registers a dataset with the given fingerprint, storing the new dataset's - // id in `*dataset-id`. + // id in `dataset_id`. Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset, - int64* dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); + int64& dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a - // stub and stores it in `worker_stubs_`. + // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is + // stored in `out_stub`. Status GetOrCreateWorkerStub(const std::string& worker_address, - WorkerService::Stub** out_stub) + WorkerService::Stub*& out_stub) LOCKS_EXCLUDED(mu_); - // Creates a job and stores it in `*job`. This method updates the + // Creates a job and stores it in `job`. This method updates the // dispatcher state with the new job, but does not assign tasks to workers. Status CreateJob(int64 dataset_id, ProcessingMode processing_mode, absl::optional named_job_key, - std::shared_ptr* job) + std::shared_ptr& job) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Acquires a job client id to read from the given job and sets // `job_client_id`. @@ -97,17 +98,17 @@ class DataServiceDispatcherImpl { const std::shared_ptr& job, int64& job_client_id) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates one task for each worker, for the given job. The created tasks are - // stored in `*tasks`. This method only updates dispatcher metadata with the + // stored in `tasks`. This method only updates dispatcher metadata with the // new tasks, but doesn't assign the tasks to the workers. Status CreateTasksForJob( std::shared_ptr job, - std::vector>* tasks) + std::vector>& tasks) EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Creates a new task for a job, storing the created task in `*task`. + // Creates a new task for a job, storing the created task in `task`. Status CreateTask(std::shared_ptr job, const std::string& worker_address, - std::shared_ptr* task); + std::shared_ptr& task); // Assigns the list of tasks to the workers indicated by their // `worker_address` fields. Status AssignTasks( diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc index b302810f715..3afee88262e 100644 --- a/tensorflow/core/data/service/dispatcher_state.cc +++ b/tensorflow/core/data/service/dispatcher_state.cc @@ -25,7 +25,7 @@ namespace data { DispatcherState::DispatcherState() {} -Status DispatcherState::Apply(Update update) { +Status DispatcherState::Apply(const Update& update) { switch (update.update_type_case()) { case Update::kRegisterDataset: RegisterDataset(update.register_dataset()); @@ -151,32 +151,32 @@ int64 DispatcherState::NextAvailableDatasetId() const { } Status DispatcherState::DatasetFromId( - int64 id, std::shared_ptr* dataset) const { + int64 id, std::shared_ptr& dataset) const { auto it = datasets_by_id_.find(id); if (it == datasets_by_id_.end()) { return errors::NotFound("Dataset id ", id, " not found"); } - *dataset = it->second; + dataset = it->second; return Status::OK(); } Status DispatcherState::DatasetFromFingerprint( - uint64 fingerprint, std::shared_ptr* dataset) const { + uint64 fingerprint, std::shared_ptr& dataset) const { auto it = datasets_by_fingerprint_.find(fingerprint); if (it == datasets_by_fingerprint_.end()) { return errors::NotFound("Dataset fingerprint ", fingerprint, " not found"); } - *dataset = it->second; + dataset = it->second; return Status::OK(); } Status DispatcherState::WorkerFromAddress( - const std::string& address, std::shared_ptr* worker) const { + const std::string& address, std::shared_ptr& worker) const { auto it = workers_.find(address); if (it == workers_.end()) { return errors::NotFound("Worker with address ", address, " not found."); } - *worker = it->second; + worker = it->second; return Status::OK(); } @@ -201,23 +201,23 @@ DispatcherState::ListJobs() { } Status DispatcherState::JobFromId(int64 id, - std::shared_ptr* job) const { + std::shared_ptr& job) const { auto it = jobs_.find(id); if (it == jobs_.end()) { return errors::NotFound("Job id ", id, " not found"); } - *job = it->second; + job = it->second; return Status::OK(); } Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key, - std::shared_ptr* job) const { + std::shared_ptr& job) const { auto it = named_jobs_.find(named_job_key); if (it == named_jobs_.end()) { return errors::NotFound("Named job key (", named_job_key.name, ", ", named_job_key.index, ") not found"); } - *job = it->second; + job = it->second; return Status::OK(); } @@ -239,25 +239,25 @@ int64 DispatcherState::NextAvailableJobClientId() const { } Status DispatcherState::TaskFromId(int64 id, - std::shared_ptr* task) const { + std::shared_ptr& task) const { auto it = tasks_.find(id); if (it == tasks_.end()) { return errors::NotFound("Task ", id, " not found"); } - *task = it->second; + task = it->second; return Status::OK(); } Status DispatcherState::TasksForJob( - int64 job_id, std::vector>* tasks) const { + int64 job_id, std::vector>& tasks) const { auto it = tasks_by_job_.find(job_id); if (it == tasks_by_job_.end()) { return errors::NotFound("Job ", job_id, " not found"); } - tasks->clear(); - tasks->reserve(it->second.size()); + tasks.clear(); + tasks.reserve(it->second.size()); for (const auto& task : it->second) { - tasks->push_back(task); + tasks.push_back(task); } return Status::OK(); } diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index d2080c8e10c..59d7f192fb1 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -56,7 +56,7 @@ class DispatcherState { DispatcherState& operator=(const DispatcherState&) = delete; // Applies the given update to the dispatcher's state. - Status Apply(Update update); + Status Apply(const Update& update); // A dataset registered with the dispatcher. struct Dataset { @@ -129,15 +129,15 @@ class DispatcherState { // Returns the next available dataset id. int64 NextAvailableDatasetId() const; // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. - Status DatasetFromId(int64 id, std::shared_ptr* dataset) const; + Status DatasetFromId(int64 id, std::shared_ptr& dataset) const; // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such // dataset. Status DatasetFromFingerprint(uint64 fingerprint, - std::shared_ptr* dataset) const; + std::shared_ptr& dataset) const; // Gets a worker by address. Returns NOT_FOUND if there is no such worker. Status WorkerFromAddress(const std::string& address, - std::shared_ptr* worker) const; + std::shared_ptr& worker) const; // Lists all workers registered with the dispatcher. std::vector> ListWorkers() const; @@ -146,9 +146,9 @@ class DispatcherState { // Returns a list of all jobs. std::vector> ListJobs(); // Gets a job by id. Returns NOT_FOUND if there is no such job. - Status JobFromId(int64 id, std::shared_ptr* job) const; + Status JobFromId(int64 id, std::shared_ptr& job) const; // Gets a named job by key. Returns NOT_FOUND if there is no such job. - Status NamedJobByKey(NamedJobKey key, std::shared_ptr* job) const; + Status NamedJobByKey(NamedJobKey key, std::shared_ptr& job) const; // Returns the job associated with the given job client id. Returns NOT_FOUND // if the job_client_id is unknown or has been released. @@ -160,12 +160,12 @@ class DispatcherState { // Returns the next available task id. int64 NextAvailableTaskId() const; // Gets a task by id. Returns NOT_FOUND if there is no such task. - Status TaskFromId(int64 id, std::shared_ptr* task) const; - // Stores a list of all tasks for the given job to `*tasks`. Returns NOT_FOUND + Status TaskFromId(int64 id, std::shared_ptr& task) const; + // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND // if there is no such job. Status TasksForJob(int64 job_id, - std::vector>* tasks) const; - // Stores a list of all tasks for the given worker to `*tasks`. Returns + std::vector>& tasks) const; + // Stores a list of all tasks for the given worker to `tasks`. Returns // NOT_FOUND if there is no such worker. Status TasksForWorker(const absl::string_view worker_address, std::vector>& tasks) const; diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc index 1676fc704f4..43a47f8581f 100644 --- a/tensorflow/core/data/service/dispatcher_state_test.cc +++ b/tensorflow/core/data/service/dispatcher_state_test.cc @@ -36,39 +36,39 @@ using Task = DispatcherState::Task; using ::testing::IsEmpty; using ::testing::SizeIs; -Status RegisterDataset(int64 id, uint64 fingerprint, DispatcherState* state) { +Status RegisterDataset(int64 id, uint64 fingerprint, DispatcherState& state) { Update update; RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset(); register_dataset->set_dataset_id(id); register_dataset->set_fingerprint(fingerprint); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } -Status RegisterDataset(int64 id, DispatcherState* state) { +Status RegisterDataset(int64 id, DispatcherState& state) { return RegisterDataset(id, /*fingerprint=*/1, state); } -Status RegisterWorker(std::string worker_address, DispatcherState* state) { +Status RegisterWorker(std::string worker_address, DispatcherState& state) { Update update; update.mutable_register_worker()->set_worker_address(worker_address); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } Status CreateAnonymousJob(int64 job_id, int64 dataset_id, - DispatcherState* state) { + DispatcherState& state) { Update update; CreateJobUpdate* create_job = update.mutable_create_job(); create_job->set_job_id(job_id); create_job->set_dataset_id(dataset_id); create_job->set_processing_mode(ProcessingModeDef::PARALLEL_EPOCHS); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } Status CreateNamedJob(int64 job_id, int64 dataset_id, NamedJobKey named_job_key, - DispatcherState* state) { + DispatcherState& state) { Update update; CreateJobUpdate* create_job = update.mutable_create_job(); create_job->set_job_id(job_id); @@ -77,49 +77,49 @@ Status CreateNamedJob(int64 job_id, int64 dataset_id, NamedJobKey named_job_key, NamedJobKeyDef* key = create_job->mutable_named_job_key(); key->set_name(named_job_key.name); key->set_index(named_job_key.index); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } Status AcquireJobClientId(int64 job_id, int64 job_client_id, - DispatcherState* state) { + DispatcherState& state) { Update update; AcquireJobClientUpdate* acquire_job_client = update.mutable_acquire_job_client(); acquire_job_client->set_job_id(job_id); acquire_job_client->set_job_client_id(job_client_id); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } Status ReleaseJobClientId(int64 job_client_id, int64 release_time, - DispatcherState* state) { + DispatcherState& state) { Update update; ReleaseJobClientUpdate* release_job_client = update.mutable_release_job_client(); release_job_client->set_job_client_id(job_client_id); release_job_client->set_time_micros(release_time); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } Status CreateTask(int64 task_id, int64 job_id, int64 dataset_id, - const std::string& worker_address, DispatcherState* state) { + const std::string& worker_address, DispatcherState& state) { Update update; CreateTaskUpdate* create_task = update.mutable_create_task(); create_task->set_task_id(task_id); create_task->set_job_id(job_id); create_task->set_dataset_id(dataset_id); create_task->set_worker_address(worker_address); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } -Status FinishTask(int64 task_id, DispatcherState* state) { +Status FinishTask(int64 task_id, DispatcherState& state) { Update update; FinishTaskUpdate* finish_task = update.mutable_finish_task(); finish_task->set_task_id(task_id); - TF_RETURN_IF_ERROR(state->Apply(update)); + TF_RETURN_IF_ERROR(state.Apply(update)); return Status::OK(); } } // namespace @@ -128,17 +128,17 @@ TEST(DispatcherState, RegisterDataset) { int64 id = 10; uint64 fingerprint = 20; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(id, fingerprint, &state)); + TF_EXPECT_OK(RegisterDataset(id, fingerprint, state)); EXPECT_EQ(state.NextAvailableDatasetId(), id + 1); { std::shared_ptr dataset; - TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, &dataset)); + TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, dataset)); EXPECT_EQ(dataset->dataset_id, id); } { std::shared_ptr dataset; - TF_EXPECT_OK(state.DatasetFromId(id, &dataset)); + TF_EXPECT_OK(state.DatasetFromId(id, dataset)); EXPECT_EQ(dataset->fingerprint, fingerprint); } } @@ -146,14 +146,14 @@ TEST(DispatcherState, RegisterDataset) { TEST(DispatcherState, MissingDatasetId) { DispatcherState state; std::shared_ptr dataset; - Status s = state.DatasetFromId(0, &dataset); + Status s = state.DatasetFromId(0, dataset); EXPECT_EQ(s.code(), error::NOT_FOUND); } TEST(DispatcherState, MissingDatasetFingerprint) { DispatcherState state; std::shared_ptr dataset; - Status s = state.DatasetFromFingerprint(0, &dataset); + Status s = state.DatasetFromFingerprint(0, dataset); EXPECT_EQ(s.code(), error::NOT_FOUND); } @@ -161,7 +161,7 @@ TEST(DispatcherState, NextAvailableDatasetId) { DispatcherState state; int64 id = state.NextAvailableDatasetId(); uint64 fingerprint = 20; - TF_EXPECT_OK(RegisterDataset(id, fingerprint, &state)); + TF_EXPECT_OK(RegisterDataset(id, fingerprint, state)); EXPECT_NE(state.NextAvailableDatasetId(), id); EXPECT_EQ(state.NextAvailableDatasetId(), state.NextAvailableDatasetId()); } @@ -169,9 +169,9 @@ TEST(DispatcherState, NextAvailableDatasetId) { TEST(DispatcherState, RegisterWorker) { DispatcherState state; std::string address = "test_worker_address"; - TF_EXPECT_OK(RegisterWorker(address, &state)); + TF_EXPECT_OK(RegisterWorker(address, state)); std::shared_ptr worker; - TF_EXPECT_OK(state.WorkerFromAddress(address, &worker)); + TF_EXPECT_OK(state.WorkerFromAddress(address, worker)); EXPECT_EQ(worker->address, address); } @@ -183,12 +183,12 @@ TEST(DispatcherState, ListWorkers) { std::vector> workers = state.ListWorkers(); EXPECT_THAT(workers, IsEmpty()); } - TF_EXPECT_OK(RegisterWorker(address_1, &state)); + TF_EXPECT_OK(RegisterWorker(address_1, state)); { std::vector> workers = state.ListWorkers(); EXPECT_THAT(workers, SizeIs(1)); } - TF_EXPECT_OK(RegisterWorker(address_2, &state)); + TF_EXPECT_OK(RegisterWorker(address_2, state)); { std::vector> workers = state.ListWorkers(); EXPECT_THAT(workers, SizeIs(2)); @@ -198,7 +198,7 @@ TEST(DispatcherState, ListWorkers) { TEST(DispatcherState, MissingWorker) { DispatcherState state; std::shared_ptr worker; - Status s = state.WorkerFromAddress("test_worker_address", &worker); + Status s = state.WorkerFromAddress("test_worker_address", worker); EXPECT_EQ(s.code(), error::NOT_FOUND); } @@ -213,15 +213,15 @@ TEST(DispatcherState, AnonymousJob) { int64 job_id = 3; int64 dataset_id = 10; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_EQ(state.NextAvailableJobId(), job_id + 1); EXPECT_EQ(job->dataset_id, dataset_id); EXPECT_EQ(job->job_id, job_id); std::vector> tasks; - TF_EXPECT_OK(state.TasksForJob(job_id, &tasks)); + TF_EXPECT_OK(state.TasksForJob(job_id, tasks)); EXPECT_THAT(tasks, IsEmpty()); EXPECT_FALSE(job->finished); } @@ -230,11 +230,11 @@ TEST(DispatcherState, NamedJob) { int64 job_id = 3; int64 dataset_id = 10; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); NamedJobKey named_job_key("test", 1); - TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, &state)); + TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state)); std::shared_ptr job; - TF_EXPECT_OK(state.NamedJobByKey(named_job_key, &job)); + TF_EXPECT_OK(state.NamedJobByKey(named_job_key, job)); EXPECT_EQ(state.NextAvailableJobId(), job_id + 1); EXPECT_EQ(job->dataset_id, dataset_id); EXPECT_EQ(job->job_id, job_id); @@ -247,13 +247,13 @@ TEST(DispatcherState, CreateTask) { int64 task_id = 8; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); - TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); + TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state)); EXPECT_EQ(state.NextAvailableTaskId(), task_id + 1); { std::shared_ptr task; - TF_EXPECT_OK(state.TaskFromId(task_id, &task)); + TF_EXPECT_OK(state.TaskFromId(task_id, task)); EXPECT_EQ(task->task_id, task_id); EXPECT_EQ(task->job_id, job_id); EXPECT_EQ(task->dataset_id, dataset_id); @@ -261,7 +261,7 @@ TEST(DispatcherState, CreateTask) { } { std::vector> tasks; - TF_EXPECT_OK(state.TasksForJob(job_id, &tasks)); + TF_EXPECT_OK(state.TasksForJob(job_id, tasks)); EXPECT_THAT(tasks, SizeIs(1)); } { @@ -278,15 +278,15 @@ TEST(DispatcherState, CreateTasksForSameJob) { int64 task_id_2 = 9; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); TF_EXPECT_OK( - CreateTask(task_id_1, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_1, job_id, dataset_id, worker_address, state)); TF_EXPECT_OK( - CreateTask(task_id_2, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_2, job_id, dataset_id, worker_address, state)); { std::vector> tasks; - TF_EXPECT_OK(state.TasksForJob(job_id, &tasks)); + TF_EXPECT_OK(state.TasksForJob(job_id, tasks)); EXPECT_THAT(tasks, SizeIs(2)); } } @@ -299,21 +299,21 @@ TEST(DispatcherState, CreateTasksForDifferentJobs) { int64 task_id_2 = 9; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id_1, dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id_2, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id_1, dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id_2, dataset_id, state)); TF_EXPECT_OK( - CreateTask(task_id_1, job_id_1, dataset_id, worker_address, &state)); + CreateTask(task_id_1, job_id_1, dataset_id, worker_address, state)); TF_EXPECT_OK( - CreateTask(task_id_2, job_id_2, dataset_id, worker_address, &state)); + CreateTask(task_id_2, job_id_2, dataset_id, worker_address, state)); { std::vector> tasks; - TF_EXPECT_OK(state.TasksForJob(job_id_1, &tasks)); + TF_EXPECT_OK(state.TasksForJob(job_id_1, tasks)); EXPECT_THAT(tasks, SizeIs(1)); } { std::vector> tasks; - TF_EXPECT_OK(state.TasksForJob(job_id_2, &tasks)); + TF_EXPECT_OK(state.TasksForJob(job_id_2, tasks)); EXPECT_THAT(tasks, SizeIs(1)); } } @@ -325,12 +325,12 @@ TEST(DispatcherState, CreateTasksForSameWorker) { int64 task_id_2 = 9; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); TF_EXPECT_OK( - CreateTask(task_id_1, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_1, job_id, dataset_id, worker_address, state)); TF_EXPECT_OK( - CreateTask(task_id_2, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_2, job_id, dataset_id, worker_address, state)); { std::vector> tasks; TF_EXPECT_OK(state.TasksForWorker(worker_address, tasks)); @@ -346,12 +346,12 @@ TEST(DispatcherState, CreateTasksForDifferentWorkers) { std::string worker_address_1 = "test_worker_address_1"; std::string worker_address_2 = "test_worker_address_2"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); TF_EXPECT_OK( - CreateTask(task_id_1, job_id, dataset_id, worker_address_1, &state)); + CreateTask(task_id_1, job_id, dataset_id, worker_address_1, state)); TF_EXPECT_OK( - CreateTask(task_id_2, job_id, dataset_id, worker_address_2, &state)); + CreateTask(task_id_2, job_id, dataset_id, worker_address_2, state)); { std::vector> tasks; TF_EXPECT_OK(state.TasksForWorker(worker_address_1, tasks)); @@ -367,7 +367,7 @@ TEST(DispatcherState, CreateTasksForDifferentWorkers) { TEST(DispatcherState, GetTasksForWorkerEmpty) { std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterWorker(worker_address, &state)); + TF_EXPECT_OK(RegisterWorker(worker_address, state)); { std::vector> tasks; TF_EXPECT_OK(state.TasksForWorker(worker_address, tasks)); @@ -381,15 +381,15 @@ TEST(DispatcherState, FinishTask) { int64 task_id = 4; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); - TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, &state)); - TF_EXPECT_OK(FinishTask(task_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); + TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state)); + TF_EXPECT_OK(FinishTask(task_id, state)); std::shared_ptr task; - TF_EXPECT_OK(state.TaskFromId(task_id, &task)); + TF_EXPECT_OK(state.TaskFromId(task_id, task)); EXPECT_TRUE(task->finished); std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_TRUE(job->finished); } @@ -400,24 +400,24 @@ TEST(DispatcherState, FinishMultiTaskJob) { int64 task_id_2 = 5; std::string worker_address = "test_worker_address"; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); TF_EXPECT_OK( - CreateTask(task_id_1, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_1, job_id, dataset_id, worker_address, state)); TF_EXPECT_OK( - CreateTask(task_id_2, job_id, dataset_id, worker_address, &state)); + CreateTask(task_id_2, job_id, dataset_id, worker_address, state)); - TF_EXPECT_OK(FinishTask(task_id_1, &state)); + TF_EXPECT_OK(FinishTask(task_id_1, state)); { std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_FALSE(job->finished); } - TF_EXPECT_OK(FinishTask(task_id_2, &state)); + TF_EXPECT_OK(FinishTask(task_id_2, state)); { std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_TRUE(job->finished); } } @@ -428,14 +428,14 @@ TEST(DispatcherState, AcquireJobClientId) { int64 job_client_id_2 = 2; int64 dataset_id = 10; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); - TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_1, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); + TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_1, state)); { std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_EQ(job->num_clients, 1); - TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_2, &state)); + TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_2, state)); EXPECT_EQ(job->num_clients, 2); } { @@ -456,12 +456,12 @@ TEST(DispatcherState, ReleaseJobClientId) { int64 job_client_id = 6; int64 release_time = 100; DispatcherState state; - TF_EXPECT_OK(RegisterDataset(dataset_id, &state)); - TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state)); - TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id, &state)); - TF_EXPECT_OK(ReleaseJobClientId(job_client_id, release_time, &state)); + TF_EXPECT_OK(RegisterDataset(dataset_id, state)); + TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state)); + TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id, state)); + TF_EXPECT_OK(ReleaseJobClientId(job_client_id, release_time, state)); std::shared_ptr job; - TF_EXPECT_OK(state.JobFromId(job_id, &job)); + TF_EXPECT_OK(state.JobFromId(job_id, job)); EXPECT_EQ(job->num_clients, 0); Status s = state.JobForJobClientId(job_client_id, job); EXPECT_EQ(s.code(), error::NOT_FOUND); diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc index a7a30798a93..fbfc5d20665 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc @@ -26,9 +26,9 @@ using ::grpc::ServerBuilder; using ::grpc::ServerContext; GrpcDispatcherImpl::GrpcDispatcherImpl( - ServerBuilder* server_builder, const experimental::DispatcherConfig& config) + const experimental::DispatcherConfig& config, ServerBuilder& server_builder) : impl_(config) { - server_builder->RegisterService(this); + server_builder.RegisterService(this); VLOG(1) << "Registered data service dispatcher"; } diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h index 81f1cbf6f02..171deed4792 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.h +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h @@ -25,18 +25,12 @@ namespace tensorflow { namespace data { // This class is a wrapper that handles communication for gRPC. -// -// Example usage: -// -// ::grpc::ServerBuilder builder; -// // configure builder -// GrpcDispatcherImpl data_service(&builder); -// builder.BuildAndStart() -// class GrpcDispatcherImpl : public DispatcherService::Service { public: - explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder, - const experimental::DispatcherConfig& config); + // Constructs a GrpcDispatcherImpl with the given config, and registers it + // with `server_builder`. + explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config, + ::grpc::ServerBuilder& server_builder); ~GrpcDispatcherImpl() override {} Status Start(); diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index b3a37fe0eec..ef386be4640 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -24,10 +24,10 @@ namespace data { using ::grpc::ServerBuilder; using ::grpc::ServerContext; -GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder, - const experimental::WorkerConfig& config) +GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config, + ServerBuilder& server_builder) : impl_(config) { - server_builder->RegisterService(this); + server_builder.RegisterService(this); VLOG(1) << "Registered data service worker"; } diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h index c42e5639385..3d30af9a806 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.h +++ b/tensorflow/core/data/service/grpc_worker_impl.h @@ -25,18 +25,12 @@ namespace tensorflow { namespace data { // This class is a wrapper that handles communication for gRPC. -// -// Example usage: -// -// ::grpc::ServerBuilder builder; -// // configure builder -// GrpcWorkerImpl data_service(&builder); -// builder.BuildAndStart() -// class GrpcWorkerImpl : public WorkerService::Service { public: - explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder, - const experimental::WorkerConfig& config); + // Constructs a GrpcWorkerImpl with the given config, and registers it with + // `server_builder`. + explicit GrpcWorkerImpl(const experimental::WorkerConfig& config, + ::grpc::ServerBuilder& server_builder); ~GrpcWorkerImpl() override {} Status Start(const std::string& worker_address); diff --git a/tensorflow/core/data/service/journal.cc b/tensorflow/core/data/service/journal.cc index b0ce0876c69..979fc78b7c0 100644 --- a/tensorflow/core/data/service/journal.cc +++ b/tensorflow/core/data/service/journal.cc @@ -96,7 +96,7 @@ Status FileJournalReader::EnsureInitialized() { return UpdateFile(DataServiceJournalFile(journal_dir_, 0)); } -Status FileJournalReader::Read(Update* update, bool* end_of_journal) { +Status FileJournalReader::Read(Update& update, bool& end_of_journal) { TF_RETURN_IF_ERROR(EnsureInitialized()); while (true) { tstring record; @@ -108,20 +108,20 @@ Status FileJournalReader::Read(Update* update, bool* end_of_journal) { if (errors::IsNotFound(env_->FileExists(next_journal_file))) { VLOG(3) << "Next journal file " << next_journal_file << " does not exist. End of journal reached."; - *end_of_journal = true; + end_of_journal = true; return Status::OK(); } TF_RETURN_IF_ERROR(UpdateFile(next_journal_file)); continue; } TF_RETURN_IF_ERROR(s); - if (!update->ParseFromString(record)) { + if (!update.ParseFromString(record)) { return errors::DataLoss("Failed to parse journal record."); } if (VLOG_IS_ON(4)) { - VLOG(4) << "Read journal entry: " << update->DebugString(); + VLOG(4) << "Read journal entry: " << update.DebugString(); } - *end_of_journal = false; + end_of_journal = false; return Status::OK(); } } diff --git a/tensorflow/core/data/service/journal.h b/tensorflow/core/data/service/journal.h index 3483497705e..e31830e8c35 100644 --- a/tensorflow/core/data/service/journal.h +++ b/tensorflow/core/data/service/journal.h @@ -77,9 +77,9 @@ class FileJournalWriter : public JournalWriter { class JournalReader { public: virtual ~JournalReader() = default; - // Reads the next update from the journal. Sets `*end_of_journal=true` if + // Reads the next update from the journal. Sets `end_of_journal=true` if // there are no more updates left in the journal. - virtual Status Read(Update* update, bool* end_of_journal) = 0; + virtual Status Read(Update& update, bool& end_of_journal) = 0; }; // JournalReader is not thread-safe, requiring external synchronization when @@ -93,7 +93,7 @@ class FileJournalReader : public JournalReader { FileJournalReader(const FileJournalReader&) = delete; FileJournalReader& operator=(const FileJournalReader&) = delete; - Status Read(Update* update, bool* end_of_journal) override; + Status Read(Update& update, bool& end_of_journal) override; private: // Initializes the reader if it is not yet initialized. diff --git a/tensorflow/core/data/service/journal_test.cc b/tensorflow/core/data/service/journal_test.cc index 313b216fe76..3f55447cc68 100644 --- a/tensorflow/core/data/service/journal_test.cc +++ b/tensorflow/core/data/service/journal_test.cc @@ -28,12 +28,12 @@ namespace data { namespace { using ::testing::HasSubstr; -bool NewJournalDir(std::string* journal_dir) { +bool NewJournalDir(std::string& journal_dir) { std::string filename = testing::TmpDir(); if (!Env::Default()->CreateUniqueFileName(&filename, "journal_dir")) { return false; } - *journal_dir = filename; + journal_dir = filename; return true; } @@ -67,7 +67,7 @@ Status CheckJournalContent(StringPiece journal_dir, for (const auto& update : expected) { Update result; bool end_of_journal = true; - TF_RETURN_IF_ERROR(reader.Read(&result, &end_of_journal)); + TF_RETURN_IF_ERROR(reader.Read(result, end_of_journal)); EXPECT_FALSE(end_of_journal); // We can't use the testing::EqualsProto matcher because it is not available // in OSS. @@ -75,7 +75,7 @@ Status CheckJournalContent(StringPiece journal_dir, } Update result; bool end_of_journal = false; - TF_RETURN_IF_ERROR(reader.Read(&result, &end_of_journal)); + TF_RETURN_IF_ERROR(reader.Read(result, end_of_journal)); EXPECT_TRUE(end_of_journal); return Status::OK(); } @@ -83,7 +83,7 @@ Status CheckJournalContent(StringPiece journal_dir, TEST(Journal, RoundTripMultiple) { std::string journal_dir; - EXPECT_TRUE(NewJournalDir(&journal_dir)); + EXPECT_TRUE(NewJournalDir(journal_dir)); std::vector updates = {MakeCreateJobUpdate(), MakeRegisterDatasetUpdate(), MakeFinishTaskUpdate()}; @@ -97,7 +97,7 @@ TEST(Journal, RoundTripMultiple) { TEST(Journal, AppendExistingJournal) { std::string journal_dir; - EXPECT_TRUE(NewJournalDir(&journal_dir)); + EXPECT_TRUE(NewJournalDir(journal_dir)); std::vector updates = {MakeCreateJobUpdate(), MakeRegisterDatasetUpdate(), MakeFinishTaskUpdate()}; @@ -111,17 +111,17 @@ TEST(Journal, AppendExistingJournal) { TEST(Journal, MissingFile) { std::string journal_dir; - EXPECT_TRUE(NewJournalDir(&journal_dir)); + EXPECT_TRUE(NewJournalDir(journal_dir)); FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(&result, &end_of_journal); + Status s = reader.Read(result, end_of_journal); EXPECT_TRUE(errors::IsNotFound(s)); } TEST(Journal, NonRecordData) { std::string journal_dir; - EXPECT_TRUE(NewJournalDir(&journal_dir)); + EXPECT_TRUE(NewJournalDir(journal_dir)); TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(journal_dir)); { @@ -134,14 +134,14 @@ TEST(Journal, NonRecordData) { FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(&result, &end_of_journal); + Status s = reader.Read(result, end_of_journal); EXPECT_THAT(s.error_message(), HasSubstr("corrupted record")); EXPECT_EQ(s.code(), error::DATA_LOSS); } TEST(Journal, InvalidRecordData) { std::string journal_dir; - EXPECT_TRUE(NewJournalDir(&journal_dir)); + EXPECT_TRUE(NewJournalDir(journal_dir)); TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(journal_dir)); { @@ -155,7 +155,7 @@ TEST(Journal, InvalidRecordData) { FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(&result, &end_of_journal); + Status s = reader.Read(result, end_of_journal); EXPECT_THAT(s.error_message(), HasSubstr("Failed to parse journal record")); EXPECT_EQ(s.code(), error::DATA_LOSS); } diff --git a/tensorflow/core/data/service/local_credentials_factory.cc b/tensorflow/core/data/service/local_credentials_factory.cc index 136bf49df9b..b9426e77a7d 100644 --- a/tensorflow/core/data/service/local_credentials_factory.cc +++ b/tensorflow/core/data/service/local_credentials_factory.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/core/data/service/credentials_factory.h" namespace tensorflow { @@ -23,14 +24,14 @@ class LocalCredentialsFactory : public CredentialsFactory { std::string Protocol() override { return "grpc+local"; } Status CreateServerCredentials( - std::shared_ptr<::grpc::ServerCredentials>* out) override { - *out = grpc::experimental::LocalServerCredentials(LOCAL_TCP); + std::shared_ptr<::grpc::ServerCredentials>& out) override { + out = grpc::experimental::LocalServerCredentials(LOCAL_TCP); return Status::OK(); } Status CreateClientCredentials( - std::shared_ptr<::grpc::ChannelCredentials>* out) override { - *out = grpc::experimental::LocalCredentials(LOCAL_TCP); + std::shared_ptr<::grpc::ChannelCredentials>& out) override { + out = grpc::experimental::LocalCredentials(LOCAL_TCP); return Status::OK(); } }; @@ -38,8 +39,7 @@ class LocalCredentialsFactory : public CredentialsFactory { class LocalCredentialsRegistrar { public: LocalCredentialsRegistrar() { - auto factory = new LocalCredentialsFactory(); - CredentialsFactory::Register(factory); + CredentialsFactory::Register(absl::make_unique()); } }; static LocalCredentialsRegistrar registrar; diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index 4ee186cd9ec..83a6e67c584 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -46,13 +46,13 @@ Status GrpcDataServerBase::Start() { ::grpc::ServerBuilder builder; std::shared_ptr<::grpc::ServerCredentials> credentials; TF_RETURN_IF_ERROR( - CredentialsFactory::CreateServerCredentials(protocol_, &credentials)); + CredentialsFactory::CreateServerCredentials(protocol_, credentials)); builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_), credentials, &bound_port_); builder.SetMaxReceiveMessageSize(-1); - AddDataServiceToBuilder(&builder); - AddProfilerServiceToBuilder(&builder); + AddDataServiceToBuilder(builder); + AddProfilerServiceToBuilder(builder); server_ = builder.BuildAndStart(); if (!server_) { return errors::Internal("Could not start gRPC server"); @@ -81,9 +81,9 @@ void GrpcDataServerBase::Join() { server_->Wait(); } int GrpcDataServerBase::BoundPort() { return bound_port(); } void GrpcDataServerBase::AddProfilerServiceToBuilder( - ::grpc::ServerBuilder* builder) { - profiler_service_ = CreateProfilerService(); - builder->RegisterService(profiler_service_.get()); + ::grpc::ServerBuilder& builder) { + profiler_service_ = profiler::CreateProfilerService(); + builder.RegisterService(profiler_service_.get()); } DispatchGrpcDataServer::DispatchGrpcDataServer( @@ -94,8 +94,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer( DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; } void DispatchGrpcDataServer::AddDataServiceToBuilder( - ::grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + ::grpc::ServerBuilder& builder) { + service_ = absl::make_unique(config_, builder).release(); } Status DispatchGrpcDataServer::StartServiceInternal() { @@ -122,8 +122,8 @@ WorkerGrpcDataServer::WorkerGrpcDataServer( WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; } void WorkerGrpcDataServer::AddDataServiceToBuilder( - ::grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + ::grpc::ServerBuilder& builder) { + service_ = absl::make_unique(config_, builder).release(); } Status WorkerGrpcDataServer::StartServiceInternal() { @@ -139,14 +139,14 @@ Status WorkerGrpcDataServer::StartServiceInternal() { } Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr* out_server) { - *out_server = absl::make_unique(config); + std::unique_ptr& out_server) { + out_server = absl::make_unique(config); return Status::OK(); } Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr* out_server) { - *out_server = absl::make_unique(config); + std::unique_ptr& out_server) { + out_server = absl::make_unique(config); return Status::OK(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index 0ddc80676c3..c45ec144652 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -53,8 +53,8 @@ class GrpcDataServerBase { int BoundPort(); protected: - virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0; - void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder); + virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0; + void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder); // Starts the service. This will be called after building the service, so // bound_port() will return the actual bound port. virtual Status StartServiceInternal() = 0; @@ -84,7 +84,7 @@ class DispatchGrpcDataServer : public GrpcDataServerBase { Status NumWorkers(int* num_workers); protected: - void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override; + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; Status StartServiceInternal() override; private: @@ -99,7 +99,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { ~WorkerGrpcDataServer() override; protected: - void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override; + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; Status StartServiceInternal() override; private: @@ -108,13 +108,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { GrpcWorkerImpl* service_; }; -// Creates a dispatch tf.data server and stores it in `*out_server`. +// Creates a dispatch tf.data server and stores it in `out_server`. Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr* out_server); + std::unique_ptr& out_server); -// Creates a worker tf.data server and stores it in `*out_server`. +// Creates a worker tf.data server and stores it in `out_server`. Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr* out_server); + std::unique_ptr& out_server); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc index 8ae3f191407..49f7eaef30d 100644 --- a/tensorflow/core/data/service/test_cluster.cc +++ b/tensorflow/core/data/service/test_cluster.cc @@ -49,7 +49,7 @@ Status TestCluster::Initialize() { experimental::DispatcherConfig config; config.set_port(0); config.set_protocol(kProtocol); - TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_)); + TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_)); TF_RETURN_IF_ERROR(dispatcher_->Start()); dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort()); workers_.reserve(num_workers_); @@ -67,7 +67,7 @@ Status TestCluster::AddWorker() { config.set_protocol(kProtocol); config.set_dispatcher_address(dispatcher_address_); config.set_worker_address("localhost:%port%"); - TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker)); + TF_RETURN_IF_ERROR(NewWorkerServer(config, worker)); TF_RETURN_IF_ERROR(worker->Start()); worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort())); workers_.push_back(std::move(worker)); diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index 4215b163991..203c63cf2f9 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -155,9 +155,6 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( PopulateTensorFromExtra(extra, to_tensor); } } - if (!s.ok() && errors::IsFailedPrecondition(s)) { - dev_resolver_->ClearTask(peer_task); - } delete state; done(s); diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc index ab0b3a60600..8b459c2613e 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc @@ -129,25 +129,4 @@ Status DeviceResolverDistributed::GetTaskCached( return Status::OK(); } -void DeviceResolverDistributed::ClearTask(const string& task) { - mutex_lock l(mu_); - // First find all the keys belonging to the task. - std::unordered_set task_keys; - for (const auto& it : attr_table_) { - const string& device_name = it.first; - if (DeviceNameUtils::IsSameAddressSpace(task, device_name)) { - task_keys.insert(device_name); - } - } - // Then delete them. - for (const string& key : task_keys) { - attr_table_.erase(key); - } -} - -void DeviceResolverDistributed::ClearCache() { - mutex_lock l(mu_); - attr_table_.clear(); -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h index d400fb5750e..a041557acff 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h @@ -46,10 +46,6 @@ class DeviceResolverDistributed : public DeviceResolverInterface { Status GetTaskCached(const string& task, std::vector* attributes) override; - void ClearTask(const string& task) override; - - void ClearCache() override; - protected: // Loads attr_table_ with device attributes retrieved from remote task. void RefreshRemoteAttributes(const string& device, const string& task, diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc index 3d7523f945c..25f3665a9c8 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc @@ -178,55 +178,6 @@ class DeviceResDistTest : public ::testing::Test { wc_.AddWorker(worker_name, fw); } - void RestartWorker(const string& worker_name, const string& device_type, - int num_devices, uint64 device_incarnation_base) { - for (auto it : resolvers_) { - it.second->ClearCache(); - } - // `DefineWorker` creates a device resolver and a worker and adds them to - // resolvers_ and workers_. Recreating the worker would overwrite these map - // entries. We destroy the old device resolver here; all other objects are - // cleaned up in the destructor. - delete resolvers_[worker_name]; - DefineWorker(worker_name, device_type, num_devices, - device_incarnation_base); - } - - void ResolveIncarnationsAndValidate( - const int num_workers, const int num_devices, const string& worker_prefix, - const string& device_type, - const std::vector>& expected_incarnations) { - for (int w = 0; w < num_workers; ++w) { - const string worker_name = absl::StrCat(worker_prefix, w); - auto* device_resolver = resolvers_[worker_name]; - const string device_prefix = - absl::StrCat(worker_name, "/device:", device_type, ":"); - for (int peer_w = 0; peer_w < num_workers; ++peer_w) { - const string peer_worker_name = absl::StrCat(worker_prefix, peer_w); - for (int d = 0; d < num_devices; ++d) { - const string device_name = - absl::StrCat(peer_worker_name, "/device:", device_type, ":", d); - DeviceNameUtils::ParsedName parsed; - ASSERT_TRUE(DeviceNameUtils::ParseFullName(device_name, &parsed)); - // NOLINT prevents linter from suggesting absl::Notification as a - // replacement, which is not available in OSS. - Notification note; // NOLINT - Status status; - DeviceAttributes attributes; - device_resolver->GetDeviceAttributesAsync( - device_name, peer_worker_name, &attributes, - [¬e, &status](const Status& s) { - status = s; - note.Notify(); - }); - note.WaitForNotification(); - TF_EXPECT_OK(status); - EXPECT_EQ(attributes.incarnation(), expected_incarnations[peer_w][d]); - } - } - } - } - FakeCache wc_; std::vector device_mgrs_; std::unordered_map resolvers_; @@ -259,52 +210,6 @@ TEST_F(DeviceResDistTest, Workers3Devices4) { } } } - // Clear just task 0 from all. - const string w0_name = "/job:worker/replica:0/task:0"; - for (auto it : resolvers_) { - if (it.first == w0_name) continue; - TestableDeviceResolverDistributed* dres = it.second; - EXPECT_EQ(8, it.second->attr_table().size()); - dres->ClearTask("/job:worker/replica:0/task:0"); - EXPECT_EQ(4, it.second->attr_table().size()); - } -} - -TEST_F(DeviceResDistTest, DeviceIncarnationChangesOnFailure) { - constexpr int num_workers = 3; - constexpr int num_devices = 4; - constexpr int failing_worker_index = 1; - const string device_type = "CPU"; - constexpr uint64 device_incarnation_base = 100; - DefineWorkers(num_workers, num_devices, device_type, device_incarnation_base); - const string worker_prefix = "/job:worker/replica:0/task:"; - const string failing_worker = - absl::StrCat(worker_prefix, failing_worker_index); - - // Check device incarnations match expected. - std::vector> expected_incarnations(num_workers); - for (int w = 0; w < num_workers; ++w) { - expected_incarnations[w].resize(num_devices); - for (int d = 0; d < num_devices; ++d) { - expected_incarnations[w][d] = - w * num_devices + d + device_incarnation_base; - } - } - ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix, - device_type, expected_incarnations); - - // Restart worker `failing_worker`. - constexpr uint64 restart_incarnation_base = 200; - RestartWorker(failing_worker, device_type, num_devices, - restart_incarnation_base); - for (int d = 0; d < num_devices; ++d) { - expected_incarnations[failing_worker_index][d] = - d + restart_incarnation_base; - } - - // Check incarnations have changed for `failing worker`. - ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix, - device_type, expected_incarnations); } } // namespace diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 6e706179863..d529abef36c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -35,11 +35,10 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/protobuf/transport_options.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { -const int kMaxWorkerRpcRetries = 10; - class GrpcRemoteWorker : public WorkerInterface { public: explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, @@ -274,7 +273,7 @@ class GrpcRemoteWorker : public WorkerInterface { bool fail_fast = true) { new RPCState( &stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_, /*max_retries=*/0, fail_fast, &target_); + callback_threadpool_, MaxRetries(), fail_fast, &target_); } void IssueRequest(const protobuf::Message* request, TensorResponse* response, @@ -282,7 +281,7 @@ class GrpcRemoteWorker : public WorkerInterface { CallOptions* call_opts = nullptr) { new RPCState(&stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_, /*max_retries=*/0, + callback_threadpool_, MaxRetries(), /*fail_fast=*/true, &target_); } @@ -299,6 +298,14 @@ class GrpcRemoteWorker : public WorkerInterface { // Helper function for initializing the RpcMethod objects below. const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); } + // Helper function for configuring max GRPC retries. Defaults to 0 (no + // retries). + const int64 MaxRetries() { + int64 max_retries = -1; + TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries)); + return max_retries; + } + SharedGrpcChannelPtr channel_; ::grpc::GenericStub stub_; ::grpc::CompletionQueue* cq_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 83e072559e9..71be10f69e5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -249,7 +249,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { .release(); eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); - profiler_service_ = CreateProfilerService(); + profiler_service_ = profiler::CreateProfilerService(); builder.RegisterService(profiler_service_.get()); // extra service: diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index d0c53231403..72e0b3d9224 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -170,12 +170,6 @@ class DeviceResolverInterface { // Returns the cached device attributes of a task. virtual Status GetTaskCached(const string& task, std::vector* attributes) = 0; - - // Clears the cache of device data belonging to the specified task. - virtual void ClearTask(const string& task) = 0; - - // Clears the cache of all device data. - virtual void ClearCache() = 0; }; // Interface that provides resolution of shared CollectiveParams fields. diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index ebf06c7d0cd..564290bcb21 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1454,6 +1454,12 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { return Status::OK(); } +void FunctionLibraryDefinition::Clear() { + mutex_lock l(mu_); + function_defs_.clear(); + func_grad_.clear(); +} + Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3c7c09eee37..3c048161b7d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -403,6 +403,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // are no longer in use. Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); + // Removes all the functions and gradient functions. + void Clear() TF_LOCKS_EXCLUDED(mu_); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index a62acfe571e..38ab8be291d 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -1068,6 +1068,16 @@ TEST(FunctionLibraryDefinitionTest, RemoveFunction) { EXPECT_FALSE(lib_def.Contains("XTimesTwo")); } +TEST(FunctionLibraryDefinitionTest, Clear) { + FunctionLibraryDefinition lib_def(OpRegistry::Global(), {}); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XAddX())); + + lib_def.Clear(); + EXPECT_FALSE(lib_def.Contains("XTimesTwo")); + EXPECT_FALSE(lib_def.Contains("XAddX")); +} + TEST(FunctionLibraryDefinitionTest, AddLibrary) { // Create lib def with single function FunctionDefLibrary proto; diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index bfa6e31209a..826c2ed5cee 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -77,7 +77,14 @@ struct Parameter { Parameter(const string& name, std::shared_ptr state, double min, double max) : name(name), - value(state->value), + // Sometimes non-autotune nodes (with `autotune_=false`) may contain + // parameters (for example inputs of parallel interleave dataset which + // are not in the current cycle). To avoid unrealistic situation + // (say `buffer_size=-1` or `parallelism=-1`) in the optimization + // computation, if the state value is `kAutotune=-1` (just to indicate + // the `SharedState` is tunable), we initialize the parameter value to + // be the minimal value of the state. + value(state->value == kAutotune ? min : state->value), min(min), max(max), state(std::move(state)) {} diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index dbe103088c1..70253f4e7c8 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -332,6 +332,11 @@ class TensorShape : public TensorShapeBase { friend class Tensor; }; +/// Outputs `TensorShapeBase` to `std::ostream`. +inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) { + return os << ts.DebugString(); +} + /// Represents the value of one dimension in a TensorShape. struct TensorShapeDim { explicit TensorShapeDim(int64 s) : size(s) {} diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index f85683a39c2..26965a87708 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -334,6 +334,8 @@ bool IsImmutableConst(const NodeDef& node) { bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; } +bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; } + bool IsLess(const NodeDef& node) { return node.op() == "Less"; } bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index b3d94e8274b..ef4f6aebe27 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -99,6 +99,7 @@ bool IsIgammac(const NodeDef& node); bool IsImag(const NodeDef& node); bool IsImmutableConst(const NodeDef& node); bool IsInvGrad(const NodeDef& node); +bool IsLeakyRelu(const NodeDef& node); bool IsLess(const NodeDef& node); bool IsLessEqual(const NodeDef& node); bool IsLog(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 9d2925e8452..d1870468ecb 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -880,6 +880,7 @@ tf_cuda_cc_test( deps = [ ":remapper", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc index ee8f9e84765..52f1ba59b32 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc +++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc @@ -89,8 +89,15 @@ Status DisableIntraOpParallelism::OptimizeAndCollectStats( // `max_intra_op_parallelism` input *insert_node.mutable_input()->Add() = max_parallelism_value->name(); - for (const auto& attr_name : {"output_types", "output_shapes"}) { - graph_utils::CopyAttribute(attr_name, *last_node, &insert_node); + // Set `output_types` and `output_shapes` attributes by copying the relevant + // attrs from the input node. If we fail to set the attributes, we abort the + // rewrite. + for (auto attr : {"output_shapes", "output_types"}) { + if (last_node->attr().find(attr) != last_node->attr().end()) { + graph_utils::CopyAttribute(attr, *last_node, &insert_node); + } else { + return Status::OK(); + } } auto* added_node = graph.AddNode(std::move(insert_node)); diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc index 27915e2d5d6..211b53ba083 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.cc +++ b/tensorflow/core/grappler/optimizers/data/slack.cc @@ -101,10 +101,9 @@ Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, return Status::OK(); } - return errors::InvalidArgument( - "Encountered unsupported op \"", dataset_node->op(), - "\" when rewriting the input pipeline graph to use slack in its " - "final prefetch transformation."); + LOG(WARNING) << "Could not find a final `prefetch` in the input pipeline to " + "which to introduce slack."; + return Status::OK(); } Status Slack::OptimizeAndCollectStats(Cluster* cluster, diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index c7f2de3d274..306ae2a9485 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -364,7 +364,8 @@ bool IsSupportedActivation(const NodeDef& node) { #ifdef INTEL_MKL return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsTanh(node); #else - return IsRelu(node) || IsRelu6(node) || IsElu(node); + // Disable LeakyRelu temporarily before MKL PR is merged. + return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node); #endif } @@ -462,6 +463,9 @@ bool FindContractionWithBiasAndActivation( // Currently, only matmul + bias + tanh is enable if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false; + // Currently, only conv + bias + leakyrelu is enabled + if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false; + // Check that data type and data format are supported on assigned device. const ContractionWithBiasAddAndActivation pattern{base.contraction, base.bias_add, node_index}; @@ -734,6 +738,16 @@ bool FindContractionWithBiasAndAddActivation( return false; } + // Get the contraction node + const auto* bias_add_node_view = + add_node_view->GetRegularFanin(base.port_id).node_view(); + const auto* contraction_node_view = + bias_add_node_view->GetRegularFanin(0).node_view(); + const auto* contraction_node_def = contraction_node_view->node(); + + // Currently, only conv + bias + add + leakyrelu is enabled + if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false; + // We successfully found a Conv2D+BiasAdd+AddN+activation pattern. const ContractionWithBiasAndAddActivation pattern{ base.contraction, base.bias_add, base.add, base.port_id, node_index}; @@ -934,7 +948,8 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, return false; } -void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) { +void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d, + const NodeDef* activation = nullptr) { DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D"; auto* attr = fused_conv2d->mutable_attr(); @@ -947,10 +962,16 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) { (*attr)["dilations"] = src_attr.at("dilations"); (*attr)["data_format"] = src_attr.at("data_format"); (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); + // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha + if (activation != nullptr && IsLeakyRelu(*activation)) { + auto& activation_attr = activation->attr(); + (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha"); + } } void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d, - NodeDef* fused_dw_conv2d) { + NodeDef* fused_dw_conv2d, + const NodeDef* activation = nullptr) { DCHECK(IsDepthwiseConv2dNative(dw_conv2d)) << "Input node must be a DepthwiseConv2dNative"; @@ -962,6 +983,11 @@ void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d, (*attr)["padding"] = src_attr.at("padding"); (*attr)["dilations"] = src_attr.at("dilations"); (*attr)["data_format"] = src_attr.at("data_format"); + // Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr leakyrelu_alpha + if (activation != nullptr && IsLeakyRelu(*activation)) { + auto& activation_attr = activation->attr(); + (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha"); + } } void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm, @@ -1064,6 +1090,7 @@ Status AddFusedContractionNode( const NodeDef& contraction = graph->node(matched.contraction); const NodeDef& bias_add = graph->node(matched.bias_add); const NodeDef& activation = graph->node(matched.activation); + VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and " << activation.op() << ":" << " activation=" << activation.name() @@ -1079,7 +1106,8 @@ Status AddFusedContractionNode( if (IsConv2D(contraction)) { fused_op.set_op(kFusedConv2D); - CopyConv2DAttributes(contraction, &fused_op); + // leaky relu has a special attribute alpha + CopyConv2DAttributes(contraction, &fused_op, &activation); } else if (IsDepthwiseConv2dNative(contraction)) { fused_op.set_op(kFusedDepthwiseConv2dNative); CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op); @@ -1217,7 +1245,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance - CopyConv2DAttributes(contraction, &fused_conv2d); + CopyConv2DAttributes(contraction, &fused_conv2d, &activation); SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()}, /*num_args=*/4, /*epsilon=*/matched.epsilon); @@ -1299,7 +1327,7 @@ Status AddFusedContractionNode( fused_conv2d.add_input(add.input(1 - matched.port_id)); CopyConv2DAttributes(contraction, &fused_conv2d); - SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", "Relu"}, 2); + SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); Status status; diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index f4bc5e38526..4cc133286d7 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/remapper.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" @@ -541,7 +542,7 @@ TEST_F(RemapperTest, DISABLED_FuseConv2DWithBiasAndActivationOnGPU) { TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) { using ::tensorflow::ops::Placeholder; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto input_shape = Placeholder::Shape({8, 32, 32, 3}); @@ -567,6 +568,13 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) { return ops::Identity(fetch, ops::Relu6(activate, bias_add)); } else if (activation == "Elu") { return ops::Identity(fetch, ops::Elu(activate, bias_add)); + // Disable LeakyRelu temporarily before MKL PR is merged. +#ifndef INTEL_MKL + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(0.5); + return ops::Identity( + fetch, ops::internal::LeakyRelu(activate, bias_add, attr)); +#endif // !INTEL_MKL } return ops::Identity(fetch, bias); @@ -605,6 +613,12 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) { ASSERT_EQ(fused_ops.size(), 2); EXPECT_EQ(fused_ops[0], "BiasAdd"); EXPECT_EQ(fused_ops[1], activation); + +#ifndef INTEL_MKL + if (activation == "LeakyRelu") { + EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5); + } +#endif // !INTEL_MKL found++; } } @@ -795,7 +809,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { using ops::Placeholder; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); @@ -828,6 +842,13 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y)); } else if (activation == "Elu") { return ops::Identity(fetch, ops::Elu(activate, batch_norm.y)); + // Disable LeakyRelu temporarily before MKL PR is merged. +#ifndef INTEL_MKL + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(0.5); + return ops::Identity( + fetch, ops::internal::LeakyRelu(activate, batch_norm.y, attr)); +#endif // !INTEL_MKL } return ops::Identity(fetch, batch_norm.y); @@ -874,6 +895,12 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { ASSERT_EQ(fused_ops.size(), 2); EXPECT_EQ(fused_ops[0], "FusedBatchNorm"); EXPECT_EQ(fused_ops[1], activation); + +#ifndef INTEL_MKL + if (activation == "LeakyRelu") { + EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5); + } +#endif // !INTEL_MKL found++; } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0981bf8d65b..d3f2d474f91 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1661,6 +1661,7 @@ tf_cuda_cc_test( ":ops_testutil", ":ops_util", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -1671,6 +1672,7 @@ tf_cuda_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels/image", + "//tensorflow/core/platform:tf32_utils", "@com_google_absl//absl/algorithm:container", ], ) @@ -5912,6 +5914,7 @@ filegroup( "avgpooling_op.h", "batch_matmul_op_impl.h", "batch_norm_op.h", + "bincount_op.h", "broadcast_to_op.h", "bucketize_op.h", "control_flow_ops.h", @@ -6129,6 +6132,7 @@ filegroup( ":android_extended_ops_headers", "base64_ops.cc", "batchtospace_op.cc", + "bincount_op.cc", "broadcast_to_op.cc", "bucketize_op.cc", "ctc_decoder_ops.cc", diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc index b137413d5e3..94ba4d86adb 100644 --- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc @@ -126,7 +126,6 @@ struct BincountFunctor { return GpuLaunchKernel(BincountReduceKernel, config.block_count, config.thread_per_block, 0, d.stream(), arr.data(), output.data(), nthreads, num_bins); - return Status::OK(); } }; @@ -215,14 +214,11 @@ struct BincountReduceFunctor { config.block_count, config.thread_per_block, smem_usage, d.stream(), in.data(), weights.data(), weights.size(), out.data(), num_rows, num_cols, num_bins); - } else { - return GpuLaunchKernel( - BincountColReduceKernel, config.block_count, - config.thread_per_block, 0, d.stream(), in.data(), weights.data(), - weights.size(), out.data(), num_rows, num_cols, num_bins); } - - return Status::OK(); + return GpuLaunchKernel( + BincountColReduceKernel, config.block_count, + config.thread_per_block, 0, d.stream(), in.data(), weights.data(), + weights.size(), out.data(), num_rows, num_cols, num_bins); } }; diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h index f838d05decf..43aa6c6e7fb 100644 --- a/tensorflow/core/kernels/conv_ops_fused_impl.h +++ b/tensorflow/core/kernels/conv_ops_fused_impl.h @@ -106,6 +106,21 @@ class LaunchFusedConv2DWithOutputKernel { template void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx, const Tensor& input, const Tensor& filter, Tensor* output) { + // Wrap output_kernel into type erased function to reduce the number of + // unique template instantiations for Eigen Tensor contraction expressions. + using OutputKernelFn = + std::function&, + const Eigen::TensorContractionParams&, Eigen::Index, + Eigen::Index, Eigen::Index, Eigen::Index)>; + + OutputKernelFn output_kernel_fn = + [&output_kernel]( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, Eigen::Index i, + Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) { + output_kernel(output_mapper, params, i, j, num_rows, num_cols); + }; + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) { int conv_width = 1; // Width for the convolution step. @@ -115,12 +130,12 @@ class LaunchFusedConv2DWithOutputKernel { Eigen::array, 1> dim_pair; dim_pair[0] = Eigen::IndexPair(1, 0); - functor::MatMulConvFunctor()( + functor::MatMulConvFunctor()( ctx->eigen_device(), output->shaped({conv_width, filter.dim_size(3)}), input.shaped({conv_width, filter.dim_size(2)}), filter.shaped({filter.dim_size(2), filter.dim_size(3)}), - dim_pair, output_kernel); + dim_pair, std::move(output_kernel_fn)); } else if (filter.dim_size(0) == input.dim_size(1) && filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 && @@ -132,29 +147,30 @@ class LaunchFusedConv2DWithOutputKernel { Eigen::array, 1> dim_pair; dim_pair[0] = Eigen::IndexPair(1, 0); - functor::MatMulConvFunctor()( + functor::MatMulConvFunctor()( ctx->eigen_device(), output->shaped({input.dim_size(0), filter.dim_size(3)}), input.shaped({input.dim_size(0), k}), filter.shaped({k, filter.dim_size(3)}), dim_pair, - output_kernel); + std::move(output_kernel_fn)); } else { if (padding_ == EXPLICIT) { - functor::SpatialConvolution()( + functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride_, col_stride_, row_dilation_, col_dilation_, static_cast(explicit_paddings_[2]), static_cast(explicit_paddings_[3]), static_cast(explicit_paddings_[4]), - static_cast(explicit_paddings_[5]), output_kernel); + static_cast(explicit_paddings_[5]), + std::move(output_kernel_fn)); } else { - functor::SpatialConvolution()( + functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride_, col_stride_, row_dilation_, col_dilation_, - BrainPadding2EigenPadding(padding_), output_kernel); + BrainPadding2EigenPadding(padding_), std::move(output_kernel_fn)); } } } @@ -185,14 +201,26 @@ struct LaunchFusedConv2DOp { BiasAddArgs bias_add_args; if (BiasAddArgs::IsSupported(fusion)) { - OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args)); + if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) { + OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args, + &fusion_args.leakyrelu_alpha)); + } else { + OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args)); + } } FusedBatchNormArgs fused_batch_norm_args; if (FusedBatchNormArgs::IsSupported(fusion)) { - OP_REQUIRES_OK(context, - InitFusedBatchNormArgs(context, fusion_args.epsilon, - &fused_batch_norm_args)); + if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) { + OP_REQUIRES_OK(context, + InitFusedBatchNormArgs(context, fusion_args.epsilon, + &fused_batch_norm_args, + &fusion_args.leakyrelu_alpha)); + } else { + OP_REQUIRES_OK(context, + InitFusedBatchNormArgs(context, fusion_args.epsilon, + &fused_batch_norm_args)); + } } LaunchFusedConv2DWithOutputKernel conv2d( @@ -215,6 +243,10 @@ struct LaunchFusedConv2DOp { conv2d(WithBiasAddAndRelu6(bias_add_args), context, input, filter, output); break; + case FusedComputationType::kBiasAddWithLeakyRelu: + conv2d(WithBiasAddAndLeakyRelu(bias_add_args), context, input, + filter, output); + break; case FusedComputationType::kBiasAddWithElu: conv2d(WithBiasAddAndElu(bias_add_args), context, input, filter, output); @@ -234,6 +266,11 @@ struct LaunchFusedConv2DOp { fused_batch_norm_args), context, input, filter, output); break; + case FusedComputationType::kFusedBatchNormWithLeakyRelu: + conv2d(WithFusedBatchNormAndLeakyRelu(fusion_args.epsilon, + fused_batch_norm_args), + context, input, filter, output); + break; case FusedComputationType::kFusedBatchNormWithElu: conv2d(WithFusedBatchNormAndElu(fusion_args.epsilon, fused_batch_norm_args), @@ -681,10 +718,12 @@ class FusedConv2DOp : public OpKernel { {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}, {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}}, {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}, + {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}}, {FCT::kFusedBatchNorm, {"FusedBatchNorm"}}, {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}}, {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}}, {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}}, + {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}}, }; } diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index 3e192b83c57..e8e156b009c 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/image_ops.h" #include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/fake_input.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/tf32_utils.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/public/session.h" @@ -652,6 +654,8 @@ class FusedConv2DOpTest : public OpsTestBase { ops::Relu6(root.WithOpName("with_activation"), with_bias); } else if (activation_type == "Elu") { ops::Elu(root.WithOpName("with_activation"), with_bias); + } else if (activation_type == "LeakyRelu") { + ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias); } else { ops::Identity(root.WithOpName("with_activation"), with_bias); } @@ -721,6 +725,9 @@ class FusedConv2DOpTest : public OpsTestBase { ops::Relu6(root.WithOpName("with_activation"), with_fused_batch_norm.y); } else if (activation_type == "Elu") { ops::Elu(root.WithOpName("with_activation"), with_fused_batch_norm.y); + } else if (activation_type == "LeakyRelu") { + ops::internal::LeakyRelu(root.WithOpName("with_activation"), + with_fused_batch_norm.y); } else { ops::Identity(root.WithOpName("with_activation"), with_fused_batch_norm.y); @@ -1038,9 +1045,10 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolution) { #endif TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) { + tensorflow::allow_tf32_execution(false); // Requires full precision Conv2D op const int filter_size = 1; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBiasAndActivation(activation, filter_size, filter_count); } @@ -1049,7 +1057,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) { TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndActivation) { const int filter_size = TestFixture::kImageWidth; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBiasAndActivation(activation, filter_size, filter_count); } @@ -1058,7 +1066,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndActivation) { TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolutionAndActivation) { const int filter_size = 3; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBiasAndActivation(activation, filter_size, filter_count); } @@ -1069,7 +1077,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolutionAndActivation) { const int filter_size = 3; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBiasAndActivation( activation, filter_size, filter_count, /*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0}); @@ -1112,7 +1120,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ExplicitPaddingConvolution) { TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, OneByOneConvolutionAndActivation) { const int filter_size = 1; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size, filter_count); } @@ -1122,7 +1130,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ImageSizeConvolutionAndActivation) { const int filter_size = TestFixture::kImageWidth; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size, filter_count); } @@ -1131,7 +1139,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, SpatialConvolutionAndActivation) { const int filter_size = 3; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size, filter_count); } @@ -1142,7 +1150,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ExplicitPaddingConvolutionAndActivation) { const int filter_size = 3; const int filter_count = 12; - for (const string& activation : {"Relu", "Relu6", "Elu"}) { + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { this->VerifyConv2DWithBatchNormAndActivation( activation, filter_size, filter_count, /*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0}); diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc index 926284571ed..175cba3f63c 100644 --- a/tensorflow/core/kernels/cwise_op_sigmoid.cc +++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_gradients.h" namespace tensorflow { -REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double, - complex64, complex128); +REGISTER6(UnaryOp, CPU, "Sigmoid", functor::sigmoid, bfloat16, float, + Eigen::half, double, complex64, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double); @@ -27,8 +27,8 @@ REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half, REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float); #endif // TENSORFLOW_USE_SYCL -REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float, - Eigen::half, double, complex64, complex128); +REGISTER6(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, bfloat16, + float, Eigen::half, double, complex64, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float, Eigen::half, double); diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 0066764baa0..681bc1f7c35 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -519,13 +519,6 @@ Status FunctionMetadata::Create( return Status::OK(); } } - for (const auto& node : fdef->node_def()) { - if (node.op() == kDataServiceDataset) { - return errors::InvalidArgument( - "The `.distribute(...)` dataset transformation is not supported " - "within tf.data functions."); - } - } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 1c354153ec2..d89392598d5 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -225,7 +225,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { [&]() { return dispatcher_->CreateJob(dataset()->dataset_id_, dataset()->processing_mode_, - &job_client_id_); + job_client_id_); }, "create job", deadline_micros)); } else { @@ -233,7 +233,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { [&]() { return dispatcher_->GetOrCreateJob( dataset()->dataset_id_, dataset()->processing_mode_, - dataset()->job_name_, iterator_index_, &job_client_id_); + dataset()->job_name_, iterator_index_, job_client_id_); }, "get or create job", deadline_micros)); } @@ -347,7 +347,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(3) << "Updating tasks"; std::vector tasks; bool job_finished; - Status s = dispatcher_->GetTasks(job_client_id_, &tasks, &job_finished); + Status s = dispatcher_->GetTasks(job_client_id_, tasks, job_finished); if (!s.ok()) { LOG(WARNING) << "Failed to get task info for job client id " << job_client_id_ << ": " << s; @@ -382,7 +382,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { TaskInfo& task_info = new_task_entry.second; std::unique_ptr worker; Status s = CreateDataServiceWorkerClient(task_info.worker_address(), - dataset()->protocol_, &worker); + dataset()->protocol_, worker); if (!s.ok()) { status_ = s; get_next_cv_.notify_all(); @@ -489,8 +489,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { CompressedElement compressed; bool end_of_sequence; for (int num_retries = 0;; ++num_retries) { - Status s = task->worker->GetElement(task->task_id, &compressed, - &end_of_sequence); + Status s = task->worker->GetElement(task->task_id, compressed, + end_of_sequence); if (s.ok()) { break; } @@ -629,7 +629,7 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str)); ProcessingMode processing_mode; OP_REQUIRES_OK(ctx, - ParseProcessingMode(processing_mode_str, &processing_mode)); + ParseProcessingMode(processing_mode_str, processing_mode)); tstring address; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address)); diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc index ba175815c73..91bc7420eeb 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc @@ -63,7 +63,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) { int64 deadline_micros = EnvTime::NowMicros() + kRetryTimeoutMicros; OP_REQUIRES_OK( ctx, grpc_util::Retry( - [&]() { return client.RegisterDataset(graph_def, &dataset_id); }, + [&]() { return client.RegisterDataset(graph_def, dataset_id); }, /*description=*/"register dataset", deadline_micros)); Tensor* output; diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index 7b64d9e8484..2bcc2b6ec65 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -48,6 +48,14 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/transform_output_iterator.h" +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +using stream_executor::cuda::ScopedActivateExecutorContext; +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm.h" +using stream_executor::rocm::ScopedActivateExecutorContext; +#endif // GOOGLE_CUDA + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -302,6 +310,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { TensorReference partition_ref(partition_count); auto wrapped_callback = [this, c, &data, &partitions, indices_out, partition_ref, cpu_tensor, done]() { + auto stream = c->op_device_context()->stream(); + ScopedActivateExecutorContext scoped_activation{stream->parent()}; + OpOutputList outputs; this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done); if (!c->status().ok()) { diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.cc b/tensorflow/core/kernels/fused_eigen_output_kernels.cc index 94e621ae05b..e8e9fd6407e 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.cc +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.cc @@ -60,18 +60,25 @@ Status InitializeFusedComputation( if (*fused_computation == FusedComputationType::kBiasAdd || *fused_computation == FusedComputationType::kBiasAddWithRelu || *fused_computation == FusedComputationType::kBiasAddWithRelu6 || - *fused_computation == FusedComputationType::kBiasAddWithElu) { + *fused_computation == FusedComputationType::kBiasAddWithElu || + *fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) { if (num_args != 1) { return errors::InvalidArgument( "Fused ", kernel_name, " with BiasAdd must have one extra argument: bias."); } + if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) { + TF_RETURN_IF_ERROR(context->GetAttr( + "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha)); + } } if (*fused_computation == FusedComputationType::kFusedBatchNorm || *fused_computation == FusedComputationType::kFusedBatchNormWithRelu || *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 || - *fused_computation == FusedComputationType::kFusedBatchNormWithElu) { + *fused_computation == FusedComputationType::kFusedBatchNormWithElu || + *fused_computation == + FusedComputationType::kFusedBatchNormWithLeakyRelu) { if (num_args != 4) { return errors::InvalidArgument( "Fused ", kernel_name, @@ -80,6 +87,11 @@ Status InitializeFusedComputation( } TF_RETURN_IF_ERROR( context->GetAttr("epsilon", &fused_computation_args->epsilon)); + if (*fused_computation == + FusedComputationType::kFusedBatchNormWithLeakyRelu) { + TF_RETURN_IF_ERROR(context->GetAttr( + "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha)); + } } return Status::OK(); diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h index 2588da10f58..546cf39e094 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.h +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -39,15 +39,18 @@ enum class FusedComputationType { kBiasAddWithRelu, kBiasAddWithRelu6, kBiasAddWithElu, + kBiasAddWithLeakyRelu, kFusedBatchNorm, kFusedBatchNormWithRelu, kFusedBatchNormWithRelu6, - kFusedBatchNormWithElu + kFusedBatchNormWithElu, + kFusedBatchNormWithLeakyRelu }; // We have to pass around additional arguments for all possible fusion types. struct FusedComputationArgs { - float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only + float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only + float leakyrelu_alpha = 0.0; // Used by `LeakyRelu` fusion only }; struct FusedComputationPattern { @@ -111,15 +114,32 @@ struct Elu { }; }; +// Applies `LeakyRelu` to the passed input expression. +struct LeakyRelu { + template + static auto apply(XprType expr, const float leakyrelu_alpha) -> decltype( + (expr < std::declval()) + .select(expr * + expr.constant(std::declval()), + expr)) { + return (expr < static_cast(0)) + .select(expr * expr.constant(static_cast( + leakyrelu_alpha)), + expr); + }; +}; + template struct BiasAddArgs { const T* bias_add_data = nullptr; + float leakyrelu_alpha; static bool IsSupported(FusedComputationType fusion) { return fusion == FusedComputationType::kBiasAdd || fusion == FusedComputationType::kBiasAddWithRelu || fusion == FusedComputationType::kBiasAddWithRelu6 || - fusion == FusedComputationType::kBiasAddWithElu; + fusion == FusedComputationType::kBiasAddWithElu || + fusion == FusedComputationType::kBiasAddWithLeakyRelu; } }; @@ -134,11 +154,14 @@ struct FusedBatchNormArgs { // scaling_factor = (estimated_variance + epsilon).rsqrt() * scale Eigen::Tensor scaling_factor; + float leakyrelu_alpha; + static bool IsSupported(FusedComputationType fusion) { return fusion == FusedComputationType::kFusedBatchNorm || fusion == FusedComputationType::kFusedBatchNormWithRelu || fusion == FusedComputationType::kFusedBatchNormWithRelu6 || - fusion == FusedComputationType::kFusedBatchNormWithElu; + fusion == FusedComputationType::kFusedBatchNormWithElu || + fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu; } }; @@ -203,6 +226,34 @@ struct BiasAddOutputKernel { const T* bias_data; }; +template +struct BiasAddOutputKernel { + explicit BiasAddOutputKernel(const BiasAddArgs& args) + : bias_data(args.bias_add_data), leakyrelu_alpha(args.leakyrelu_alpha) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* bias_base = bias_data + i; + typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + const auto expr = output + bias; + output = LeakyRelu::template apply(expr, leakyrelu_alpha); + } + } + + private: + const T* bias_data; + float leakyrelu_alpha; +}; + // Output kernel that fuses FusedBatchNorm operation into the output of tensor // contraction + activation function defined by Activation. template @@ -247,6 +298,51 @@ struct FusedBatchNormOutputKernel { const T* estimated_mean_data; }; +template +struct FusedBatchNormOutputKernel { + FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs& args) + : epsilon(epsilon), + scaling_factor_data(args.scaling_factor.data()), + offset_data(args.offset_data), + estimated_mean_data(args.estimated_mean_data), + leakyrelu_alpha(args.leakyrelu_alpha) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* scaling_factor_base = scaling_factor_data + i; + const T* offset_base = offset_data + i; + const T* mean_base = estimated_mean_data + i; + + typename TTypes::UnalignedConstTensor scaling_factor(scaling_factor_base, + num_rows); + typename TTypes::UnalignedConstTensor offset(offset_base, num_rows); + typename TTypes::UnalignedConstTensor mean(mean_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + + auto scaled = (output - mean) * scaling_factor; + auto shifted = scaled + offset; + + output = LeakyRelu::template apply(shifted, + leakyrelu_alpha); + } + } + + private: + T epsilon; + const T* scaling_factor_data; + const T* offset_data; + const T* estimated_mean_data; + float leakyrelu_alpha; +}; + // Type aliases for the output kernels, purely for the sake of better launch // dispatching code readability. template @@ -258,6 +354,8 @@ using WithBiasAddAndRelu6 = BiasAddOutputKernel; template using WithBiasAddAndElu = BiasAddOutputKernel; template +using WithBiasAddAndLeakyRelu = BiasAddOutputKernel; +template using WithFusedBatchNorm = FusedBatchNormOutputKernel; template using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel; @@ -265,9 +363,12 @@ template using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel; template using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel; template -Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args) { +Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, + const float* leakyrelu_alpha = nullptr) { // Bias of the following dimensions: [ output_depth ] const Tensor& bias = context->input(2); @@ -281,12 +382,17 @@ Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args) { args->bias_add_data = data_ptr(bias); + if (leakyrelu_alpha) { + args->leakyrelu_alpha = *leakyrelu_alpha; + } + return Status::OK(); } template Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, - FusedBatchNormArgs* args) { + FusedBatchNormArgs* args, + const float* leakyrelu_alpha = nullptr) { const Tensor& scale = context->input(2); const Tensor& offset = context->input(3); const Tensor& estimated_mean = context->input(4); @@ -319,6 +425,10 @@ Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, (estimated_variance.flat() + static_cast(epsilon)).rsqrt() * scale.flat(); + if (leakyrelu_alpha) { + args->leakyrelu_alpha = *leakyrelu_alpha; + } + return Status::OK(); } diff --git a/tensorflow/core/lib/lmdb/BUILD b/tensorflow/core/lib/lmdb/BUILD new file mode 100644 index 00000000000..c863d4c4ab5 --- /dev/null +++ b/tensorflow/core/lib/lmdb/BUILD @@ -0,0 +1,28 @@ +# Description: +# lmdb test data packages. + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "lmdb_testdata", + testonly = 1, + srcs = [ + # A simple key-value store: + # 0 : 'b' + # 1 : 'b' + # ... + # 9 : 'b' + # Which is then overwritten with: + # 0 : 'a' + # 1 : 'b' + # ... + # 9 : 'j' + "testdata/data.mdb", + # LMDB, being a memory-mapped database, uses a different file format on + # big-endian systems. + "testdata/data_bigendian.mdb", + ], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt index 4c931ccac4d..bf147acf0f4 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt @@ -32,8 +32,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -89,8 +87,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -148,8 +144,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -206,14 +200,12 @@ op { type: DT_UINT8 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 type: DT_INT64 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 type: DT_BFLOAT16 type: DT_UINT16 - type: DT_COMPLEX128 type: DT_HALF type: DT_UINT32 type: DT_UINT64 @@ -234,3 +226,63 @@ op { } } } +op { + name: "Max" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt index f0ebdb0e41f..4959b5e8d58 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt @@ -32,8 +32,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -89,8 +87,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -148,8 +144,6 @@ op { type: DT_UINT16 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 @@ -206,14 +200,12 @@ op { type: DT_UINT8 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 type: DT_INT64 type: DT_QINT8 type: DT_QUINT8 type: DT_QINT32 type: DT_BFLOAT16 type: DT_UINT16 - type: DT_COMPLEX128 type: DT_HALF type: DT_UINT32 type: DT_UINT64 @@ -234,3 +226,63 @@ op { } } } +op { + name: "Min" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index cbf1ef53dde..3afdb679497 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1026,7 +1026,7 @@ REGISTER_OP("Min") .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") - .Attr("T: numbertype") + .Attr("T: {realnumbertype, quantizedtype}") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape); @@ -1035,7 +1035,7 @@ REGISTER_OP("Max") .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") - .Attr("T: numbertype") + .Attr("T: {realnumbertype, quantizedtype}") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 1ef0e82bf4a..b34bb4131d9 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -404,6 +404,8 @@ REGISTER_OP("_FusedConv2D") .Attr("fused_ops: list(string) = []") // Attributes for the FusedBatchNorm ------------------------------------ // .Attr("epsilon: float = 0.0001") + // Attributes for the LeakyRelu ----------------------------------------- // + .Attr("leakyrelu_alpha: float = 0.2") // ---------------------------------------------------------------------- // .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) .Doc(R"doc( @@ -633,7 +635,10 @@ REGISTER_OP("_FusedDepthwiseConv2dNative") .Attr("fused_ops: list(string) = []") // Attributes for the FusedBatchNorm ------------------------------------ // .Attr("epsilon: float = 0.0001") + // Attributes for the LeakyRelu ----------------------------------------- // + .Attr("leakyrelu_alpha: float = 0.2") // ---------------------------------------------------------------------- // + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); // -------------------------------------------------------------------------- diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 5b4de7eb980..cd9a234ada1 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -23634,17 +23634,17 @@ op { type: DT_UINT8 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 type: DT_INT64 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 type: DT_BFLOAT16 type: DT_UINT16 - type: DT_COMPLEX128 type: DT_HALF type: DT_UINT32 type: DT_UINT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 } } } @@ -24792,17 +24792,17 @@ op { type: DT_UINT8 type: DT_INT16 type: DT_INT8 - type: DT_COMPLEX64 type: DT_INT64 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 type: DT_BFLOAT16 type: DT_UINT16 - type: DT_COMPLEX128 type: DT_HALF type: DT_UINT32 type: DT_UINT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 } } } diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 7b716798c28..308d8a09fa7 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -111,6 +111,13 @@ class Env { Status NewRandomAccessFile(const std::string& fname, std::unique_ptr* result); + Status NewRandomAccessFile(const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + // We duplicate these methods due to Google internal coding style prevents + // virtual functions with default arguments. See PR #41615. + return Status::OK(); + } + /// \brief Creates an object that writes to a new file with the specified /// name. /// @@ -127,6 +134,11 @@ class Env { Status NewWritableFile(const std::string& fname, std::unique_ptr* result); + Status NewWritableFile(const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return Status::OK(); + } + /// \brief Creates an object that either appends to an existing file, or /// writes to a new file (if the file does not exist to begin with). /// @@ -142,6 +154,10 @@ class Env { Status NewAppendableFile(const std::string& fname, std::unique_ptr* result); + Status NewAppendableFile(const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return Status::OK(); + } /// \brief Creates a readonly region of memory with the file context. /// /// On success, it returns a pointer to read-only memory region @@ -156,21 +172,41 @@ class Env { Status NewReadOnlyMemoryRegionFromFile( const std::string& fname, std::unique_ptr* result); + Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return Status::OK(); + } + /// Returns OK if the named path exists and NOT_FOUND otherwise. Status FileExists(const std::string& fname); + Status FileExists(const std::string& fname, TransactionToken* token) { + return Status::OK(); + } + /// Returns true if all the listed files exist, false otherwise. /// if status is not null, populate the vector with a detailed status /// for each file. bool FilesExist(const std::vector& files, std::vector* status); + bool FilesExist(const std::vector& files, TransactionToken* token, + std::vector* status) { + return true; + } + /// \brief Stores in *result the names of the children of the specified /// directory. The names are relative to "dir". /// /// Original contents of *results are dropped. Status GetChildren(const std::string& dir, std::vector* result); + Status GetChildren(const std::string& dir, TransactionToken* token, + std::vector* result) { + return Status::OK(); + } + /// \brief Returns true if the path matches the given pattern. The wildcards /// allowed in pattern are described in FileSystem::GetMatchingPaths. virtual bool MatchPath(const std::string& path, @@ -183,9 +219,18 @@ class Env { virtual Status GetMatchingPaths(const std::string& pattern, std::vector* results); + Status GetMatchingPaths(const std::string& pattern, TransactionToken* token, + std::vector* results) { + return Status::OK(); + } + /// Deletes the named file. Status DeleteFile(const std::string& fname); + Status DeleteFile(const std::string& fname, TransactionToken* token) { + return Status::OK(); + } + /// \brief Deletes the specified directory and all subdirectories and files /// underneath it. This is accomplished by traversing the directory tree /// rooted at dirname and deleting entries as they are encountered. @@ -213,6 +258,11 @@ class Env { Status DeleteRecursively(const std::string& dirname, int64* undeleted_files, int64* undeleted_dirs); + Status DeleteRecursively(const std::string& dirname, TransactionToken* token, + int64* undeleted_files, int64* undeleted_dirs) { + return Status::OK(); + } + /// \brief Creates the specified directory and all the necessary /// subdirectories. Typical return codes. /// * OK - successfully created the directory and sub directories, even if @@ -220,18 +270,35 @@ class Env { /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. Status RecursivelyCreateDir(const std::string& dirname); + Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) { + return Status::OK(); + } /// \brief Creates the specified directory. Typical return codes /// * OK - successfully created the directory. /// * ALREADY_EXISTS - directory already exists. /// * PERMISSION_DENIED - dirname is not writable. Status CreateDir(const std::string& dirname); + Status CreateDir(const std::string& dirname, TransactionToken* token) { + return Status::OK(); + } + /// Deletes the specified directory. Status DeleteDir(const std::string& dirname); + Status DeleteDir(const std::string& dirname, TransactionToken* token) { + return Status::OK(); + } + /// Obtains statistics for the given path. Status Stat(const std::string& fname, FileStatistics* stat); + Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) { + return Status::OK(); + } + /// \brief Returns whether the given path is a directory or not. /// Typical return codes (not guaranteed exhaustive): /// * OK - The path exists and is a directory. @@ -256,13 +323,59 @@ class Env { /// Stores the size of `fname` in `*file_size`. Status GetFileSize(const std::string& fname, uint64* file_size); + Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) { + return Status::OK(); + } + /// \brief Renames file src to target. If target already exists, it will be /// replaced. Status RenameFile(const std::string& src, const std::string& target); + Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) { + return Status::OK(); + } + /// \brief Copy the src to target. Status CopyFile(const std::string& src, const std::string& target); + Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) { + return Status::OK(); + } + + /// \brief starts a new transaction on the filesystem that handles filename + Status StartTransaction(const std::string& filename, + TransactionToken** token) { + token = nullptr; + return Status::OK(); + } + + /// \brief Adds `path` to transaction in `token` if token belongs to + /// filesystem that handles the path. + Status AddToTransaction(const std::string& path, TransactionToken* token) { + return Status::OK(); + } + + /// \brief Get token for `path` or start a new transaction and add `path` to + /// it. + Status GetTokenOrStartTransaction(const std::string& path, + TransactionToken** token) { + *token = nullptr; + return Status::OK(); + } + + /// \brief Returns the transaction for `path` or nullptr in `token` + Status GetTransactionForPath(const std::string& path, + TransactionToken** token) { + token = nullptr; + return Status::OK(); + } + + /// \brief Finalizes the transaction + Status EndTransaction(TransactionToken* token) { return Status::OK(); } + /// \brief Returns the absolute path of the current executable. It resolves /// symlinks if there is any. std::string GetExecutablePath(); diff --git a/tensorflow/core/platform/ram_file_system.h b/tensorflow/core/platform/ram_file_system.h index 407bcb3ba0f..ce6d05486e5 100644 --- a/tensorflow/core/platform/ram_file_system.h +++ b/tensorflow/core/platform/ram_file_system.h @@ -177,7 +177,7 @@ class RamFileSystem : public FileSystem { FileStatistics* stat) override { mutex_lock m(mu_); auto it = fs_.lower_bound(fname); - if (it == fs_.end()) { + if (it == fs_.end() || !absl::StartsWith(it->first, fname)) { return errors::NotFound(""); } diff --git a/tensorflow/core/platform/ram_file_system_test.py b/tensorflow/core/platform/ram_file_system_test.py index 0f4f47ec44e..960765d68a2 100644 --- a/tensorflow/core/platform/ram_file_system_test.py +++ b/tensorflow/core/platform/ram_file_system_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import def_function from tensorflow.python.estimator.estimator import Estimator from tensorflow.python.estimator.model_fn import EstimatorSpec from tensorflow.python.estimator.run_config import RunConfig @@ -28,9 +29,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.layers import core as core_layers +from tensorflow.python.module import module from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.saved_model import saved_model from tensorflow.python.training import adam from tensorflow.python.training import training_util @@ -82,6 +85,17 @@ class RamFilesystemTest(test_util.TensorFlowTestCase): matches = ['ram://c/b/%d.txt' % i for i in range(10)] self.assertEqual(gfile.Glob('ram://c/b/*'), matches) + def test_file_exists(self): + with gfile.GFile('ram://exists/a/b/c.txt', 'w') as f: + f.write('') + self.assertTrue(gfile.Exists('ram://exists/a')) + self.assertTrue(gfile.Exists('ram://exists/a/b')) + self.assertTrue(gfile.Exists('ram://exists/a/b/c.txt')) + + self.assertFalse(gfile.Exists('ram://exists/b')) + self.assertFalse(gfile.Exists('ram://exists/a/c')) + self.assertFalse(gfile.Exists('ram://exists/a/b/k')) + def test_estimator(self): def model_fn(features, labels, mode, params): @@ -114,6 +128,18 @@ class RamFilesystemTest(test_util.TensorFlowTestCase): estimator.train(input_fn=input_fn, steps=10) estimator.train(input_fn=input_fn, steps=10) + def test_savedmodel(self): + class MyModule(module.Module): + + @def_function.function(input_signature=[]) + def foo(self): + return constant_op.constant([1]) + + saved_model.save(MyModule(), 'ram://my_module') + + loaded = saved_model.load('ram://my_module') + self.assertAllEqual(loaded.foo(), [1]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/core/platform/tf32_utils.cc b/tensorflow/core/platform/tf32_utils.cc index d2f40ea161a..21059b98e11 100644 --- a/tensorflow/core/platform/tf32_utils.cc +++ b/tensorflow/core/platform/tf32_utils.cc @@ -20,8 +20,8 @@ limitations under the License. namespace tensorflow { // Whether TensorFloat-32 should be used where supported. -// TODO(nluehr): Maybe enable by default after TF32 Ampere testing. -static std::atomic tf32_allowed{false}; +// TODO(reedwm): Change word "allow" to "enable" in all TensorFloat-32 functions +static std::atomic tf32_allowed{true}; void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; } diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc index 425bf0077c3..4cf81f422af 100644 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc +++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc @@ -36,6 +36,9 @@ void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) { DCHECK(dst != nullptr); DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id()); DCHECK_EQ(src.name(), dst->name()); + if (dst->long_name().empty()) { + dst->set_long_name(src.long_name()); + } if (dst->category().empty()) { dst->set_category(src.category()); } diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 276181dd7bb..8f58b7bf3ae 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -234,7 +234,8 @@ OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats) { uint64 outside_compilation_device_op_time_ps = 0; for (const OpMetrics& metrics : op_stats.device_op_metrics_db().metrics_db()) { - if (!IsOutsideCompilationOp(metrics.provenance(), metrics.name())) continue; + if (!IsOutsideCompilationOp(metrics.provenance(), metrics.long_name())) + continue; outside_compilation_device_op_time_ps += metrics.self_time_ps(); } uint64 num_total_tf_ops = num_host_tf_ops + num_device_tf_ops; diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc index a5e127a45d0..48354874509 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc @@ -29,6 +29,10 @@ namespace tensorflow { namespace profiler { namespace { +// The maximum number of Tensorflow Ops displayed on Tensorflow Stats page. +// 500 device side ops and 500 host side ops. +const int kMaxNumOfOps = 500; + TfStatsRecord ConvertOpMetricsToTfStatsRecord( bool on_device, const OpMetrics& metrics, double ridge_point_operational_intensity) { @@ -60,7 +64,8 @@ TfStatsTable GenerateTfStatsTable( total_device_time_ps -= IdleTimePs(device_tf_metrics_db); } double total_device_time_us = PicosToMicros(total_device_time_ps); - for (const OpMetrics* metrics : SortedOpMetricsDb(device_tf_metrics_db)) { + for (const OpMetrics* metrics : + SortedOpMetricsDb(device_tf_metrics_db, kMaxNumOfOps)) { if (exclude_idle && IsIdleOp(*metrics)) continue; TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); *record = ConvertOpMetricsToTfStatsRecord( @@ -84,8 +89,8 @@ TfStatsTable GenerateTfStatsTable( total_host_time_ps -= IdleTimePs(host_tf_metrics_db); } double total_host_time_us = PicosToMicros(total_host_time_ps); - for (const OpMetrics* metrics : - tensorflow::profiler::SortedOpMetricsDb(host_tf_metrics_db)) { + for (const OpMetrics* metrics : tensorflow::profiler::SortedOpMetricsDb( + host_tf_metrics_db, kMaxNumOfOps)) { if (exclude_idle && IsIdleOp(*metrics)) continue; TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); *record = ConvertOpMetricsToTfStatsRecord( diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc index 4233c5fdd72..a6bc2a546b2 100644 --- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc @@ -58,7 +58,7 @@ class PythonTracer : public ProfilerInterface { PythonTracer::~PythonTracer() { Stop().IgnoreError(); - PythonHooks::GetSingleton()->Finalize(); + PythonHooks::GetSingleton()->Finalize(nullptr); } Status PythonTracer::Start() { @@ -76,7 +76,7 @@ Status PythonTracer::Stop() { return errors::Internal("TraceMeRecorder not started"); } VLOG(1) << __FUNCTION__; - PythonHooks::GetSingleton()->Stop(options_); + PythonHooks::GetSingleton()->Stop(); recording_ = false; return Status::OK(); } @@ -87,17 +87,12 @@ Status PythonTracer::CollectData(RunMetadata* run_metadata) { // in the wrong threads. // We had assumed HostTracer::Stop is called when ProfilerSession try to // serialize PythonTracer. - PythonHooks::GetSingleton()->Finalize(); + PythonHooks::GetSingleton()->Finalize(nullptr); return Status::OK(); } Status PythonTracer::CollectData(XSpace* space) { - // This ProfilerInterface rely on HostTracer to serialize its trace. - // Make sure unpaired traceme don't get recorded, because it will end up - // in the wrong threads. - // We had assumed HostTracer::Stop is called when ProfilerSession try to - // serialize PythonTracer. - PythonHooks::GetSingleton()->Finalize(); + PythonHooks::GetSingleton()->Finalize(space); return Status::OK(); } @@ -107,8 +102,7 @@ Status PythonTracer::CollectData(XSpace* space) { std::unique_ptr CreatePythonTracer( const ProfileOptions& options) { PythonHooksOptions pyhooks_options; - pyhooks_options.enable_trace_python_function = - options.python_tracer_level() && options.host_tracer_level(); + pyhooks_options.enable_trace_python_function = options.python_tracer_level(); pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0; return absl::make_unique(pyhooks_options); } diff --git a/tensorflow/core/profiler/protobuf/op_metrics.proto b/tensorflow/core/profiler/protobuf/op_metrics.proto index af38795b7b2..670ebd5ed67 100644 --- a/tensorflow/core/profiler/protobuf/op_metrics.proto +++ b/tensorflow/core/profiler/protobuf/op_metrics.proto @@ -26,12 +26,14 @@ message LayoutAnalysis { } // Metrics for an operation (accumulated over all occurrences). -// Next ID: 20 +// Next ID: 21 message OpMetrics { // HLO module id. 0 for TF ops. uint64 hlo_module_id = 13; // Name of this op. string name = 6; + // Long name of this op (e.g., HLO expression). + string long_name = 20; // Category of this op. string category = 11; // Provenance of this op (e.g., if HLO op, original TF op). diff --git a/tensorflow/core/profiler/rpc/client/save_profile.cc b/tensorflow/core/profiler/rpc/client/save_profile.cc index 81f9490ff76..20ff496d057 100644 --- a/tensorflow/core/profiler/rpc/client/save_profile.cc +++ b/tensorflow/core/profiler/rpc/client/save_profile.cc @@ -51,13 +51,14 @@ const absl::string_view kPathSep = "\\"; const absl::string_view kPathSep = "/"; #endif -string ProfilerJoinPathImpl(std::initializer_list paths) { - string result; +std::string ProfilerJoinPathImpl( + std::initializer_list paths) { + std::string result; for (absl::string_view path : paths) { if (path.empty()) continue; if (result.empty()) { - result = string(path); + result = std::string(path); continue; } @@ -75,7 +76,7 @@ string ProfilerJoinPathImpl(std::initializer_list paths) { // A local duplication of ::tensorflow::io::JoinPath that supports windows. // TODO(b/150699701): revert to use ::tensorflow::io::JoinPath when fixed. template -string ProfilerJoinPath(const T&... args) { +std::string ProfilerJoinPath(const T&... args) { return ProfilerJoinPathImpl({args...}); } @@ -86,8 +87,8 @@ Status DumpToolData(absl::string_view run_dir, absl::string_view host, const ProfileToolData& tool, std::ostream* os) { // Don't save the intermediate results for combining the per host tool data. if (absl::EndsWith(tool.name(), kTfStatsHelperSuffix)) return Status::OK(); - string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); - string path = + std::string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); + std::string path = ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool.name())); TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { @@ -97,7 +98,8 @@ Status DumpToolData(absl::string_view run_dir, absl::string_view host, return Status::OK(); } -Status WriteGzippedDataToFile(const string& filepath, const string& data) { +Status WriteGzippedDataToFile(const std::string& filepath, + const std::string& data) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filepath, &file)); io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP(); @@ -110,8 +112,9 @@ Status WriteGzippedDataToFile(const string& filepath, const string& data) { return Status::OK(); } -Status GetOrCreateRunDir(const string& repository_root, const string& run, - string* run_dir, std::ostream* os) { +Status GetOrCreateRunDir(const std::string& repository_root, + const std::string& run, std::string* run_dir, + std::ostream* os) { // Dumps profile data to //. *run_dir = ProfilerJoinPath(repository_root, run); *os << "Creating directory: " << *run_dir; @@ -120,21 +123,21 @@ Status GetOrCreateRunDir(const string& repository_root, const string& run, } } // namespace -string GetTensorBoardProfilePluginDir(const string& logdir) { +std::string GetTensorBoardProfilePluginDir(const std::string& logdir) { constexpr char kPluginName[] = "plugins"; constexpr char kProfileName[] = "profile"; return ProfilerJoinPath(logdir, kPluginName, kProfileName); } -Status MaybeCreateEmptyEventFile(const string& logdir) { +Status MaybeCreateEmptyEventFile(const std::string& logdir) { // Suffix for an empty event file. it should be kept in sync with // _EVENT_FILE_SUFFIX in tensorflow/python/eager/profiler.py. constexpr char kProfileEmptySuffix[] = ".profile-empty"; TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(logdir)); - std::vector children; + std::vector children; TF_RETURN_IF_ERROR(Env::Default()->GetChildren(logdir, &children)); - for (const string& child : children) { + for (const std::string& child : children) { if (absl::EndsWith(child, kProfileEmptySuffix)) { return Status::OK(); } @@ -143,10 +146,10 @@ Status MaybeCreateEmptyEventFile(const string& logdir) { return event_writer.InitWithSuffix(kProfileEmptySuffix); } -Status SaveProfile(const string& repository_root, const string& run, - const string& host, const ProfileResponse& response, +Status SaveProfile(const std::string& repository_root, const std::string& run, + const std::string& host, const ProfileResponse& response, std::ostream* os) { - string run_dir; + std::string run_dir; TF_RETURN_IF_ERROR(GetOrCreateRunDir(repository_root, run, &run_dir, os)); for (const auto& tool_data : response.tool_data()) { TF_RETURN_IF_ERROR(DumpToolData(run_dir, host, tool_data, os)); @@ -154,22 +157,24 @@ Status SaveProfile(const string& repository_root, const string& run, return Status::OK(); } -Status SaveGzippedToolData(const string& repository_root, const string& run, - const string& host, const string& tool_name, - const string& data) { - string run_dir; +Status SaveGzippedToolData(const std::string& repository_root, + const std::string& run, const std::string& host, + const std::string& tool_name, + const std::string& data) { + std::string run_dir; std::stringstream ss; Status status = GetOrCreateRunDir(repository_root, run, &run_dir, &ss); LOG(INFO) << ss.str(); TF_RETURN_IF_ERROR(status); - string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); - string path = ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool_name)); + std::string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); + std::string path = + ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool_name)); TF_RETURN_IF_ERROR(WriteGzippedDataToFile(path, data)); LOG(INFO) << "Dumped gzipped tool data for " << tool_name << " to " << path; return Status::OK(); } -string GetCurrentTimeStampAsString() { +std::string GetCurrentTimeStampAsString() { return absl::FormatTime("%E4Y_%m_%d_%H_%M_%S", absl::Now(), absl::LocalTimeZone()); } diff --git a/tensorflow/core/profiler/rpc/client/save_profile.h b/tensorflow/core/profiler/rpc/client/save_profile.h index c155502fb60..9c15ef26080 100644 --- a/tensorflow/core/profiler/rpc/client/save_profile.h +++ b/tensorflow/core/profiler/rpc/client/save_profile.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_SAVE_PROFILE_H_ #include +#include #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -25,27 +26,28 @@ limitations under the License. namespace tensorflow { namespace profiler { -string GetCurrentTimeStampAsString(); +std::string GetCurrentTimeStampAsString(); // Returns the profile plugin directory given a logdir to TensorBoard. -string GetTensorBoardProfilePluginDir(const string& logdir); +std::string GetTensorBoardProfilePluginDir(const std::string& logdir); // Creates an empty event file if not already exists, which indicates that we // have a plugins/profile/ directory in the current logdir. -Status MaybeCreateEmptyEventFile(const string& logdir); +Status MaybeCreateEmptyEventFile(const std::string& logdir); // Saves all profiling tool data in a profile to //. // This writes user-facing log messages to `os`. // Note: this function creates a directory even when all fields in // ProfileResponse are unset/empty. -Status SaveProfile(const string& repository_root, const string& run, - const string& host, const ProfileResponse& response, +Status SaveProfile(const std::string& repository_root, const std::string& run, + const std::string& host, const ProfileResponse& response, std::ostream* os); // Gzip the data and save to //. -Status SaveGzippedToolData(const string& repository_root, const string& run, - const string& host, const string& tool_name, - const string& data); +Status SaveGzippedToolData(const std::string& repository_root, + const std::string& run, const std::string& host, + const std::string& tool_name, + const std::string& data); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc index 966a94a1116..cfff3fc05de 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.cc +++ b/tensorflow/core/profiler/rpc/profiler_server.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/profiler/rpc/profiler_service_impl.h" namespace tensorflow { +namespace profiler { void ProfilerServer::StartProfilerServer(int32 port) { std::string server_address = absl::StrCat("[::]:", port); @@ -54,4 +55,5 @@ ProfilerServer::~ProfilerServer() { } } +} // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h index b7148e7e686..45680e83b6c 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.h +++ b/tensorflow/core/profiler/rpc/profiler_server.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" namespace tensorflow { +namespace profiler { class ProfilerServer { public: @@ -34,6 +35,7 @@ class ProfilerServer { std::unique_ptr<::grpc::Server> server_; }; +} // namespace profiler } // namespace tensorflow #endif // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_ diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index ba463813fc0..8eadd87bf77 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { +namespace profiler { namespace { const absl::string_view kXPlanePb = "xplane.pb"; @@ -115,4 +116,10 @@ std::unique_ptr CreateProfilerService() { return absl::make_unique(); } +} // namespace profiler + +std::unique_ptr CreateProfilerService() { + return absl::make_unique(); +} + } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h index 00a850acbf2..3960b33f5be 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.h +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h @@ -23,6 +23,11 @@ namespace tensorflow { std::unique_ptr CreateProfilerService(); +namespace profiler { + +std::unique_ptr CreateProfilerService(); + +} // namespace profiler } // namespace tensorflow #endif // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_ diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 2d3ec1d004d..d4957caaad1 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -212,6 +212,7 @@ cc_library( visibility = [":friends"], deps = [ ":timespan", + ":trace_utils", ":xplane_builder", ":xplane_visitor", "//tensorflow/core:platform_base", diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc index 867d1315053..825469f9eab 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.cc +++ b/tensorflow/core/profiler/utils/xplane_utils.cc @@ -40,13 +40,6 @@ Timespan XEventTimespan(const XEvent& event) { return Timespan(event.offset_ps(), event.duration_ps()); } -// Functor that compares XEvents of the same XLine for sorting by timespan. -struct XEventsComparator { - bool operator()(const XEvent* a, const XEvent* b) const { - return XEventTimespan(*a) < XEventTimespan(*b); - } -}; - } // namespace const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) { @@ -144,6 +137,10 @@ void RemoveEmptyLines(XPlane* plane) { lines->end()); } +bool XEventsComparator::operator()(const XEvent* a, const XEvent* b) const { + return XEventTimespan(*a) < XEventTimespan(*b); +} + void SortXPlane(XPlane* plane) { for (XLine& line : *plane->mutable_lines()) { auto& events = *line.mutable_events(); diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index ff65f5af3ef..5cd5275e85e 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/trace_utils.h" namespace tensorflow { namespace profiler { @@ -75,6 +76,26 @@ void SortXPlane(XPlane* plane); // Sorts each plane of the XSpace. void SortXSpace(XSpace* space); +// Functor that compares XEvents for sorting by timespan. +struct XEventsComparator { + bool operator()(const XEvent* a, const XEvent* b) const; +}; + +// Returns a sorted vector of all XEvents in the given XPlane. +template +std::vector GetSortedEvents(XPlane* plane, Compare comp, + bool include_derived_events = false) { + std::vector events; + for (XLine& line : *plane->mutable_lines()) { + if (!include_derived_events && IsDerivedThreadId(line.id())) continue; + for (XEvent& event : *line.mutable_events()) { + events.push_back(&event); + } + } + absl::c_sort(events, XEventsComparator()); + return events; +} + // Normalize timestamps by time-shifting to start_time_ns_ as origin. void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns); void NormalizeTimestamps(XSpace* space, uint64 start_time_ns); diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 2cdfcb33c86..c04d74bbbc6 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 505 // Updated: 2020/8/26 +#define TF_GRAPH_DEF_VERSION 509 // Updated: 2020/8/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index d8abbd042b9..d20b5abd376 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -157,6 +157,7 @@ cc_library( ":tpu_api", ":tpu_compilation_device", ":tpu_config_c_api", + ":tpu_executor_init_fns", ":tpu_library_init_fns", ":tpu_node_device", ":tpu_system_device", diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 819fd34305f..8f97b2e45fe 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -494,15 +494,15 @@ cc_library( hdrs = ["tpu_util.h"], deps = [ ":tpu_compilation_cache_key", - ":tpu_program_group_interface", + ":tpu_util_c_api_hdrs", "//tensorflow/cc:ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tpu:tpu_api", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", ], alwayslink = 1, ) @@ -925,3 +925,19 @@ cc_library( ], alwayslink = True, ) + +cc_library( + name = "tpu_pod_state", + srcs = ["tpu_pod_state.cc"], + hdrs = ["tpu_pod_state.h"], + copts = select({ + WITH_TPU_SUPPORT: ["-DLIBTFTPU"], + DEFAULT: [], + }), + deps = [ + ":tpu_compilation_cache_service", + ":tpu_util", + "//tensorflow/core:framework", + tf_grpc_cc_dependency(), + ], +) diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index b880b7ac1a2..0e77edf4ecf 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { namespace tpu { std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() { - return ::grpc::InsecureChannelCredentials(); + return ::grpc::InsecureChannelCredentials(); // NOLINT } #if defined(LIBTFTPU) diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc new file mode 100644 index 00000000000..a45a4d63708 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_pod_state.h" + +#include "tensorflow/core/tpu/kernels/tpu_util.h" + +namespace tensorflow { + +const char kTpuPodStateResourceName[] = "tpu_pod_state"; + +TpuPodState::TpuPodState( + int service_port, std::unique_ptr cache_service) + : cache_service_(std::move(cache_service)), service_port_(service_port) {} + +TpuPodState::~TpuPodState() { + if (cache_service_) { + VLOG(1) << "Shutting down Compilation Cache Service."; + if (cache_service_->Shutdown(20)) { + if (service_port_ >= 0) { + tpu::RecycleUnusedPort(service_port_); + } + } else { + LOG(ERROR) + << "Failed to shutdown Compilation Cache Service within timeout."; + } + } + VLOG(1) << "Shutting down Compilation Cache Service done."; +} + +string TpuPodState::DebugString() const { + return "Wrapper for distributed TPU state"; +} + +Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state) { + if (!rmgr) { + return errors::Internal("No resource manager."); + } + if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName, + pod_state) + .ok()) { + return errors::FailedPrecondition( + "The TPU system has not been initialized."); + } + return Status::OK(); +} + +bool HasTPUPodState(const ResourceMgr* rmgr) { + TpuPodState* pod_state; + if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName, + &pod_state) + .ok()) { + return false; + } + pod_state->Unref(); + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.h b/tensorflow/core/tpu/kernels/tpu_pod_state.h new file mode 100644 index 00000000000..9f37e28f60f --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ + +#include "grpcpp/server_builder.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h" + +namespace tensorflow { + +// Name of tpu pod state. +ABSL_CONST_INIT extern const char kTpuPodStateResourceName[]; + +// Wrapper to hold centralized state for the distributed TPU in the TPU_SYSTEM +// device's resource manager. +class TpuPodState : public ResourceBase { + public: + // The port number given by isa_cache_port will be freed with + // RecycleUnusedPort in the destructor if it is non-negative. + TpuPodState(int service_port, + std::unique_ptr cache_service); + + ~TpuPodState() override; + + string DebugString() const override; + + private: + std::unique_ptr cache_service_; + int service_port_; +}; + +// Returns the TPU pod state or an error. +Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state); + +// Checks whether the TPU POD state configuration is present within the resource +// manager. +bool HasTPUPodState(const ResourceMgr* rmgr); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc index 60f8fe0198b..837c23c6cf5 100644 --- a/tensorflow/core/tpu/kernels/tpu_util.cc +++ b/tensorflow/core/tpu/kernels/tpu_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/core/platform/random.h" +#include "tensorflow/core/tpu/tpu_api.h" namespace tensorflow { namespace tpu { @@ -95,5 +96,9 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes, } return Status::OK(); } + +void RecycleUnusedPort(int port) { + UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port); +} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h index 579fbdf5e85..834db31c3d8 100644 --- a/tensorflow/core/tpu/kernels/tpu_util.h +++ b/tensorflow/core/tpu/kernels/tpu_util.h @@ -54,6 +54,11 @@ Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes, std::vector* shapes); Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes, std::vector* shapes); + +// We only recycle ports which were given to us by the portserver. For ports +// we obtained through local trial-and-error, there is no reason to expect the +// port to remain available after it is unbound. +void RecycleUnusedPort(int port); } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_util_c_api.h b/tensorflow/core/tpu/kernels/tpu_util_c_api.h index ddc7a842f49..04b65e24e54 100644 --- a/tensorflow/core/tpu/kernels/tpu_util_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_util_c_api.h @@ -56,6 +56,9 @@ TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); +// Recycle unused service port. +TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); + // Creates a unique compilation cache `key` used for `put` and `get` operations. // Returned buffers are heap-allocated and must be owned. TFTPU_CAPI_EXPORT CompilationCacheKeyResult @@ -79,6 +82,7 @@ struct TfTpu_UtilApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); + TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index 16494d0aa86..6a1432e27fa 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -1,4 +1,8 @@ +#if defined(PLATFORM_GOOGLE) #include "third_party/tensorflow/core/tpu/tpu_executor_init_fns.inc" +#else +#include "tensorflow/core/tpu/tpu_executor_init_fns.inc" +#endif namespace { @@ -88,6 +92,7 @@ tensorflow::Status SetTpuUtilStructFns(void* library_handle) { auto* util_fn = tensorflow::tpu::UtilApiFn(); TFTPU_SET_FN(util_fn, TpuTopology_AvailableCoreCount); + TFTPU_SET_FN(util_fn, TpuNetUtil_RecycleUnusedPort); TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled); TFTPU_SET_FN(util_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation); TFTPU_SET_FN(util_fn, TpuCompile_CreateCompilationCacheKey); diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index b148ffab042..03acb98989b 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -167,11 +167,14 @@ class Feature { } // Helper methods - tstring& construct_at_end(LimitedArraySlice* bytes_list) { - return bytes_list->construct_at_end(); + tstring* construct_at_end(LimitedArraySlice* bytes_list) { + if (bytes_list->EndDistance() <= 0) { + return nullptr; + } + return &bytes_list->construct_at_end(); } - tstring& construct_at_end(SmallVector* bytes_list) { - return bytes_list->emplace_back(); + tstring* construct_at_end(SmallVector* bytes_list) { + return &bytes_list->emplace_back(); } template @@ -192,9 +195,10 @@ class Feature { // parse string uint32 bytes_length; if (!stream.ReadVarint32(&bytes_length)) return false; - tstring& bytes = construct_at_end(bytes_list); - bytes.resize_uninitialized(bytes_length); - if (!stream.ReadRaw(bytes.data(), bytes_length)) return false; + tstring* bytes = construct_at_end(bytes_list); + if (bytes == nullptr) return false; + bytes->resize_uninitialized(bytes_length); + if (!stream.ReadRaw(bytes->data(), bytes_length)) return false; } stream.PopLimit(limit); return true; diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 97a976f1145..6b7f011dbda 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -21532,6 +21532,11 @@ func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype t // // *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// +// Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor. +// +// Both input and output have a range `(-inf, inf)`. +// func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return @@ -21609,6 +21614,12 @@ func Atan(scope *Scope, x tf.Output) (y tf.Output) { } // Computes acos of x element-wise. +// +// +// Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`. +// +// Input range is `[-1, 1]` and the output has a range of `[0, pi]`. +// func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return @@ -42490,7 +42501,7 @@ func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMS // mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. // rho: Decay rate. Must be a scalar. -// +// momentum: Momentum Scale. Must be a scalar. // epsilon: Ridge term. Must be a scalar. // grad: The gradient. // diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index faf91b03800..058d2e75daa 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -16,13 +16,6 @@ exports_files(glob([ "models/testdata/*", ])) -config_setting( - name = "enable_default_profiler", - values = { - "copt": "-DTFLITE_ENABLE_DEFAULT_PROFILER", - }, -) - config_setting( name = "gemmlowp_profiling", values = { @@ -275,13 +268,9 @@ cc_library( "//tensorflow/lite/experimental/resource", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/schema:schema_fbs", - ] + select({ - ":enable_default_profiler": [ - "//tensorflow/lite/profiling:platform_profiler", - ], - "//conditions:default": [], - }), + ], alwayslink = 1, ) diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h index 880b80e69b4..152bcf986fe 100644 --- a/tensorflow/lite/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -188,7 +188,7 @@ TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetOutputTensorCount( const TfLiteInterpreter* interpreter); // Returns the tensor associated with the output index. -// REQUIRES: 0 <= input_index < TfLiteInterpreterGetOutputTensorCount(tensor) +// REQUIRES: 0 <= output_index < TfLiteInterpreterGetOutputTensorCount(tensor) // // NOTE: The shape and underlying data buffer for output tensors may be not // be available until after the output tensor has been both sized and allocated. diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index d320a90d005..31405dfb998 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -226,6 +226,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); } \ } while (0) +#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon) \ + do { \ + auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a)); \ + if (delta > epsilon) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)", \ + __FILE__, __LINE__, #a, #b, static_cast(a), \ + static_cast(b)); \ + return kTfLiteError; \ + } \ + } while (0) + #define TF_LITE_ENSURE_OK(context, status) \ do { \ const TfLiteStatus s = (status); \ diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index bd9ad12fc81..f3e8b8a4e4c 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -674,13 +674,17 @@ TfLiteStatus Subgraph::ResetVariableTensors() { continue; } - // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be - // allocated after the initial `PrepareOpsAndTensors()` is called. - TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, - kTfLiteArenaRwPersistent); - TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); - - tflite::ResetVariableTensor(&tensor); + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + // If variable tensors allocation type is `kTfLiteArenaRwPersistent`, then + // they must be allocated after the initial `PrepareOpsAndTensors()` is + // called. + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + tflite::ResetVariableTensor(&tensor); + } else { + // If variable tensors allocation type is not `kTfLiteArenaRwPersistent`, + // then it can only be `kTfLiteCustom` in which case, we do not reset it. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, kTfLiteCustom); + } } return kTfLiteOk; } diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc index eefbeb72b15..1168196476a 100644 --- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc @@ -75,6 +75,7 @@ const std::set& GetFlexAllowlist() { "BiasAdd", "BiasAddGrad", "BiasAddV1", + "Bincount", "BoostedTreesBucketize", "BroadcastArgs", "BroadcastGradientArgs", @@ -116,6 +117,7 @@ const std::set& GetFlexAllowlist() { "DecodeWav", "DeepCopy", "DeleteSessionTensor", + "DenseBincount", "DepthToSpace", "DepthwiseConv2dNative", "Dequantize", @@ -302,6 +304,7 @@ const std::set& GetFlexAllowlist() { "RFFT", "RFFT2D", "RFFT3D", + "RaggedBincount", "RaggedRange", "RaggedTensorToSparse", "RaggedTensorToTensor", @@ -416,6 +419,7 @@ const std::set& GetFlexAllowlist() { "SparseApplyProximalAdagrad", "SparseApplyProximalGradientDescent", "SparseApplyRMSProp", + "SparseBincount", "SparseCross", "SparseCrossHashed", "SparseCrossV2", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 9ae3836d6c4..a94805ea551 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -76,7 +76,10 @@ cc_test( ], deps = [ ":arguments", + ":buffer", + ":device_info", ":gpu_object", + ":tensor", ":tensor_type", "//tensorflow/lite/delegates/gpu/common:data_type", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc index 5623de2419c..b7e6b08616e 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -256,13 +256,6 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type, object_refs_[name] = {std::move(descriptor_ptr)}; } -void Arguments::AddObject(const std::string& name, AccessType access_type, - GPUObjectPtr&& object, - GPUObjectDescriptorPtr&& descriptor_ptr) { - descriptor_ptr->SetAccess(access_type); - objects_[name] = {std::move(object), std::move(descriptor_ptr)}; -} - void Arguments::AddObject(const std::string& name, GPUObjectDescriptorPtr&& descriptor_ptr) { descriptor_ptr->SetAccess(AccessType::READ); diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h index 643e1b7655d..4636a06db6f 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -39,37 +39,16 @@ class Arguments { void AddFloat(const std::string& name, float value = 0.0f); void AddHalf(const std::string& name, half value = half(0.0f)); void AddInt(const std::string& name, int value = 0); - void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); - void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); - void AddImage2DArray(const std::string& name, - const GPUImage2DArrayDescriptor& desc); - void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc); - void AddImageBuffer(const std::string& name, - const GPUImageBufferDescriptor& desc); - void AddCustomMemory(const std::string& name, - const GPUCustomMemoryDescriptor& desc); - void AddObjectRef(const std::string& name, AccessType access_type, GPUObjectDescriptorPtr&& descriptor_ptr); - void AddObject(const std::string& name, AccessType access_type, - GPUObjectPtr&& object, - GPUObjectDescriptorPtr&& descriptor_ptr); void AddObject(const std::string& name, GPUObjectDescriptorPtr&& descriptor_ptr); absl::Status SetInt(const std::string& name, int value); absl::Status SetFloat(const std::string& name, float value); absl::Status SetHalf(const std::string& name, half value); - absl::Status SetImage2D(const std::string& name, cl_mem memory); - absl::Status SetBuffer(const std::string& name, cl_mem memory); - absl::Status SetImage2DArray(const std::string& name, cl_mem memory); - absl::Status SetImage3D(const std::string& name, cl_mem memory); - absl::Status SetImageBuffer(const std::string& name, cl_mem memory); - absl::Status SetCustomMemory(const std::string& name, cl_mem memory); absl::Status SetObjectRef(const std::string& name, const GPUObject* object); - std::string GetListOfArgs(); - absl::Status Bind(cl_kernel kernel, int offset = 0); void RenameArgs(const std::string& postfix, std::string* code) const; @@ -87,6 +66,25 @@ class Arguments { Arguments& operator=(const Arguments&) = delete; private: + void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); + void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); + void AddImage2DArray(const std::string& name, + const GPUImage2DArrayDescriptor& desc); + void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc); + void AddImageBuffer(const std::string& name, + const GPUImageBufferDescriptor& desc); + void AddCustomMemory(const std::string& name, + const GPUCustomMemoryDescriptor& desc); + + absl::Status SetImage2D(const std::string& name, cl_mem memory); + absl::Status SetBuffer(const std::string& name, cl_mem memory); + absl::Status SetImage2DArray(const std::string& name, cl_mem memory); + absl::Status SetImage3D(const std::string& name, cl_mem memory); + absl::Status SetImageBuffer(const std::string& name, cl_mem memory); + absl::Status SetCustomMemory(const std::string& name, cl_mem memory); + + std::string GetListOfArgs(); + std::string AddActiveArgument(const std::string& arg_name, bool use_f32_for_halfs); void AddGPUResources(const std::string& name, const GPUResources& resources); diff --git a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc index 29a15e16a57..722ca5b1827 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc @@ -14,85 +14,58 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include #include #include #include +#include "absl/strings/match.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" namespace tflite { namespace gpu { namespace cl { -namespace { -struct TestDescriptor : public GPUObjectDescriptor { - absl::Status PerformSelector(const std::string& selector, - const std::vector& args, - const std::vector& template_args, - std::string* result) const override { - if (selector == "Length") { - *result = "length"; - return absl::OkStatus(); - } else if (selector == "Read") { - if (args.size() != 1) { - return absl::NotFoundError( - absl::StrCat("TestDescriptor Read require one argument, but ", - args.size(), " was passed")); - } - *result = absl::StrCat("buffer[", args[0], "]"); - return absl::OkStatus(); - } else { - return absl::NotFoundError(absl::StrCat( - "TestDescriptor don't have selector with name - ", selector)); - } - } - - GPUResources GetGPUResources(AccessType access_type) const override { - GPUResources resources; - resources.ints.push_back("length"); - GPUBufferDescriptor desc; - desc.data_type = DataType::FLOAT32; - desc.element_size = 4; - resources.buffers.push_back({"buffer", desc}); - return resources; - } -}; -} // namespace - TEST(ArgumentsTest, TestSelectorResolve) { - TestDescriptor descriptor; - Arguments args; - args.AddObjectRef("object", AccessType::WRITE, - absl::make_unique(descriptor)); - std::string sample_code = R"( - if (a < 3) { - value = args.object.Read(id); - } -)"; - const std::string expected_result = R"( - if (a < 3) { - value = object_buffer[id]; - } -)"; - ASSERT_OK(args.TransformToCLCode({}, &sample_code)); - EXPECT_EQ(sample_code, expected_result); + BufferDescriptor desc; + desc.element_type = DataType::FLOAT32; + desc.element_size = 4; + desc.memory_type = MemoryType::GLOBAL; - std::string cl_arguments = args.GetListOfArgs(); - EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") != - std::string::npos); + Arguments args; + args.AddObjectRef("weights", AccessType::READ, + absl::make_unique(std::move(desc))); + std::string sample_code = R"( +__kernel void main_function($0) { + if (a < 3) { + value = args.weights.Read(id); + } +})"; + + DeviceInfo device_info; + ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code)); + EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];")); + EXPECT_TRUE( + absl::StrContains(sample_code, "__global float4* weights_buffer")); } TEST(ArgumentsTest, TestNoSelector) { - TestDescriptor descriptor; + BufferDescriptor desc; + desc.element_type = DataType::FLOAT32; + desc.element_size = 4; + desc.memory_type = MemoryType::GLOBAL; + Arguments args; - args.AddObjectRef("object", AccessType::WRITE, - absl::make_unique(descriptor)); + args.AddObjectRef("weights", AccessType::READ, + absl::make_unique(std::move(desc))); std::string sample_code = R"( if (a < 3) { - value = args.object.Write(id); + value = args.weights.UnknownSelector(id); } )"; - EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok()); + DeviceInfo device_info; + EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok()); } TEST(ArgumentsTest, TestRenameArgs) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 02f5f9c4a4a..15681be5e2b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -1005,6 +1005,37 @@ cc_test( ], ) +cc_library( + name = "reduce", + srcs = ["reduce.cc"], + hdrs = ["reduce.h"], + deps = [ + ":gpu_operation", + ":util", + "//tensorflow/lite/delegates/gpu/cl:precision", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + ], +) + +cc_test( + name = "reduce_test", + srcs = ["reduce_test.cc"], + linkstatic = True, + tags = tf_gpu_tests_tags() + [ + "linux", + "local", + ], + deps = [ + ":cl_test", + ":reduce", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "relu", srcs = ["relu.cc"], @@ -1397,6 +1428,7 @@ test_suite( "padding_test", "pooling_test", "prelu_test", + "reduce_test", "relu_test", "reshape_test", "reshapex4_test", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index afec0ab8a56..3203ec3e7fc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -42,10 +42,10 @@ std::string GetOneInputCode(const OperationType& op_type, result = "\n"; break; case OperationType::ELU: - result = "$0.x = $0.x < (FLT)(0.0f) ? exp($0.x) - (FLT)(1.0f) : $0.x;\n"; - result += "$0.y = $0.y < (FLT)(0.0f) ? exp($0.y) - (FLT)(1.0f) : $0.y;\n"; - result += "$0.z = $0.z < (FLT)(0.0f) ? exp($0.z) - (FLT)(1.0f) : $0.z;\n"; - result += "$0.w = $0.w < (FLT)(0.0f) ? exp($0.w) - (FLT)(1.0f) : $0.w;\n"; + result = "$0.x = $0.x < (FLT)(0.0f) ? expm1($0.x) : $0.x;\n"; + result += "$0.y = $0.y < (FLT)(0.0f) ? expm1($0.y) : $0.y;\n"; + result += "$0.z = $0.z < (FLT)(0.0f) ? expm1($0.z) : $0.z;\n"; + result += "$0.w = $0.w < (FLT)(0.0f) ? expm1($0.w) : $0.w;\n"; break; case OperationType::EXP: result = "$0 = exp($0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index f9d6ec762ec..b34b8e38b41 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -213,7 +213,6 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) { RETURN_IF_ERROR(args_.TransformToCLCode( creation_context.device->info_, {{dst_tensors_names_[0], elementwise_code_}}, &code)); - code = absl::Substitute(code, args_.GetListOfArgs()); RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc new file mode 100644 index 00000000000..4f889d4ff0e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h" + +#include + +#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +std::string GetReduceChannelsKernelCode(const OperationDef& op_def, + const OperationType& op_type) { + std::string c = GetCommonDefines(op_def.precision); + if (op_type == OperationType::ADD) { + c += "#define OP(a, b) ((a) + (b))\n"; + } else if (op_type == OperationType::MUL) { + c += "#define OP(a, b) ((a) * (b))\n"; + } else if (op_type == OperationType::MAXIMUM) { + c += "#define OP(a, b) max(a, b)\n"; + } else if (op_type == OperationType::MINIMUM) { + c += "#define OP(a, b) min(a, b)\n"; + } + c += "__kernel void main_function($0) {\n"; + c += " int X = get_global_id(0);\n"; + c += " int Y = get_global_id(1);\n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) " + "return;\n"; + if (op_type == OperationType::ADD) { + c += " FLT4 reduced = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } else if (op_type == OperationType::MUL) { + c += " FLT4 reduced = (FLT4)(1.0f, 1.0f, 1.0f, 1.0f);\n"; + } else { + c += " FLT4 V0 = args.src_tensor.Read(X, Y, 0);\n"; + c += " FLT4 reduced = (FLT4)(V0.x, V0.x, V0.x, V0.x);\n"; + } + c += " int s = 0;\n"; + c += " for (; s < args.src_tensor.Slices() - 1; ++s) {\n"; + c += " FLT4 V = args.src_tensor.Read(X, Y, s);\n"; + c += " reduced = OP(reduced, V);\n"; + c += " }\n"; + c += " FLT reduced_final = OP(OP(reduced.x, reduced.y), OP(reduced.z, " + "reduced.w));\n"; + c += " FLT last_reduce;\n"; + c += " FLT4 last_val = args.src_tensor.Read(X, Y, s);\n"; + c += " int ch_rem = args.src_tensor.Channels() % 4;\n"; + c += " if (ch_rem == 0) {\n"; + c += " last_reduce = OP(OP(last_val.x, last_val.y), OP(last_val.z, " + "last_val.w));\n"; + c += " } else if (ch_rem == 1) {\n"; + c += " last_reduce = OP(OP(last_val.x, last_val.y), last_val.z);\n"; + c += " } else if (ch_rem == 2) {\n"; + c += " last_reduce = OP(last_val.x, last_val.y);\n"; + c += " } else {\n"; + c += " last_reduce = last_val.x;\n"; + c += " }\n"; + c += " reduced_final = OP(reduced_final, last_reduce);\n"; + c += " FLT4 result = (FLT4)(reduced_final, 0.0f, 0.0f, 0.0f);\n"; + c += " args.dst_tensor.Write(result, X, Y, 0);\n"; + c += "}\n"; + return c; +} +} // namespace + +GPUOperation CreateReduce(const OperationDef& definition, + const OperationType& op_type) { + GPUOperation op(definition); + auto src_desc = definition.src_tensors[0]; + if (definition.IsBatchSupported()) { + src_desc.SetStateVar("BatchedWidth", "true"); + } + op.AddSrcTensor("src_tensor", src_desc); + auto dst_desc = definition.dst_tensors[0]; + if (definition.IsBatchSupported()) { + dst_desc.SetStateVar("BatchedWidth", "true"); + } + op.AddDstTensor("dst_tensor", dst_desc); + op.code_ = GetReduceChannelsKernelCode(definition, op_type); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1; + return op; +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h new file mode 100644 index 00000000000..ec5329aaf8a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" + +namespace tflite { +namespace gpu { +namespace cl { + +GPUOperation CreateReduce(const OperationDef& definition, + const OperationType& op_type); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc new file mode 100644 index 00000000000..9275c451d34 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace cl { +namespace { + +TEST_F(OpenCLOperationTest, ReduceSumChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 5); + src_tensor.data = {1.1, 2.1, 0.7, 0.3, 1.2, 3.1, 4.1, 0.0, 1.0, 4.4}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReduce(op_def, OperationType::ADD); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {5.4f, 12.6f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceProductChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 2); + src_tensor.data = {1.1, 2.0, 3.1, 4.0}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReduce(op_def, OperationType::MUL); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {2.2f, 12.4f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceMaxChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 6); + src_tensor.data = {1.1, 2.0, -0.3, -100.0, 32.6, 1.1, + -3.1, -4.0, -5.0, -7.0, -2.0, -100.0}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReduce(op_def, OperationType::MAXIMUM); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {32.6f, -2.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, ReduceMinChannels) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 6); + src_tensor.data = {1.1, 2.0, -0.3, -100.0, 32.6, 1.1, + -3.1, -4.0, -5.0, -7.0, -2.0, 100.0}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateReduce(op_def, OperationType::MINIMUM); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 1), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {-100.0f, -7.0f})); + } + } +} + +} // namespace +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh index 0fd2d33de14..56d1e1010ed 100755 --- a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh +++ b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh @@ -83,11 +83,17 @@ ADB push "$model_path" "$OPENCL_DIR" declare -a BUILD_CONFIG abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r') if [[ "$abi_version" == "armeabi-v7a" ]]; then -#"32 bit" +#"32 bit ARM" BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie ) -else -#"64 bit" +elif [[ "$abi_version" == "arm64-v8a" ]]; then +#"64 bit ARM" BUILD_CONFIG=( --config=android_arm64 -c opt ) +elif [[ "$abi_version" == "x86_64" ]]; then +# x86_64 +BUILD_CONFIG=( --config=android_x86_64 -c opt ) +else +echo "Error: Unknown processor ABI" +exit 1 fi bazel build "${BUILD_CONFIG[@]}" //$SHELL_DIR:$BINARY_NAME diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index d7ddef23374..04503c88439 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -798,6 +798,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser { case OperationType::ABS: case OperationType::COPY: case OperationType::COS: + case OperationType::ELU: case OperationType::EXP: case OperationType::LOG: case OperationType::RSQRT: @@ -815,6 +816,8 @@ class ElementwiseOperationParser : public TFLiteOperationParser { bool IsTwoArgumentOperation() const { switch (operation_type_) { case OperationType::DIV: + case OperationType::MAXIMUM: + case OperationType::MINIMUM: case OperationType::POW: case OperationType::SQUARED_DIFF: case OperationType::SUB: @@ -826,8 +829,11 @@ class ElementwiseOperationParser : public TFLiteOperationParser { bool IsTwoArgumentOperationWithConst() const { switch (operation_type_) { - case OperationType::MINIMUM: + case OperationType::DIV: case OperationType::MAXIMUM: + case OperationType::MINIMUM: + case OperationType::POW: + case OperationType::SQUARED_DIFF: case OperationType::SUB: return true; default: @@ -1125,6 +1131,17 @@ class MulOperationParser : public TFLiteOperationParser { // The "larger" input tensor must be bound to 1st input and the "smaller" // input tensor ("mask") must be bound to 2nd input. if (runtime_tensor0 && runtime_tensor1) { + if (input0 == input1) { + // replace MUL(A, A) with POW(A, 2.0) + // TODO(b/166831113): Support the same inputs for operations. + node->operation.type = ToString(OperationType::POW); + ElementwiseAttributes attr; + attr.param = 2.0f; + node->operation.attributes = std::move(attr); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + return reader->AddOutputs(node); + } + BHWC shape0; RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); BHWC shape1; @@ -2390,6 +2407,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1); } @@ -2415,37 +2433,6 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { } }; -class Landmarks2TransformMatrixV2OperationParser - : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, - /*outputs=*/1); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks - RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix - - const std::string op_name = "landmarks_to_transform_matrix_v2"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } -}; - class AlignmentPointsToTransformMatrixOperationParser : public TFLiteOperationParser { public: @@ -2689,12 +2676,10 @@ std::unique_ptr NewOperationParser( if (custom_name == "TransformLandmarksV2") { return std::make_unique(); } - if (custom_name == "Landmarks2TransformMatrix") { + if (custom_name == "Landmarks2TransformMatrix" || + custom_name == "Landmarks2TransformMatrixV2") { return std::make_unique(); } - if (custom_name == "Landmarks2TransformMatrixV2") { - return std::make_unique(); - } if (custom_name == "AlignmentPointsToTransformMatrix") { return std::make_unique< AlignmentPointsToTransformMatrixOperationParser>(); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index f4f4c180976..e90f8a41c8b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -122,6 +122,7 @@ cc_library( srcs = ["conv.cc"], hdrs = ["conv.h"], deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", diff --git a/tensorflow/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md index e4abb564b26..9df0ace4db0 100644 --- a/tensorflow/lite/g3doc/performance/best_practices.md +++ b/tensorflow/lite/g3doc/performance/best_practices.md @@ -38,6 +38,12 @@ has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. +You can also use +[TensrFlow Lite tracing](measurement.md#trace_tensorflow_lite_internals_in_android) +to profile the model in your Android application, using standard Android system +tracing, and to visualize the operator invocations by time with GUI based +profiling tools. + ## Profile and optimize operators in the graph If a particular operator appears frequently in the model and, based on @@ -116,7 +122,7 @@ interpreter execution. TensorFlow Lite can use delegates by: Be aware that some accelerators work better for different types of models. Some delegates only support float models or models optimized in a specific way. It is -important to [benchmark](benchmarks.md) each delegate to see if it is a good +important to [benchmark](measurement.md) each delegate to see if it is a good choice for your application. For example, if you have a very small model, it may not be worth delegating the model to either the NN API or the GPU. Conversely, accelerators are a great choice for large models that have high arithmetic diff --git a/tensorflow/lite/g3doc/performance/measurement.md b/tensorflow/lite/g3doc/performance/measurement.md index 179406f517e..9d2f7247ac7 100644 --- a/tensorflow/lite/g3doc/performance/measurement.md +++ b/tensorflow/lite/g3doc/performance/measurement.md @@ -451,6 +451,25 @@ help notice where the inference call is made. ``` +### Enable TensorFlow Lite tracing + +To enable TensorFlow Lite tracing, set the Android system property +`debug.tflite.tracing` to 1 before starting the Android app. + +```shell +adb shell setprop debug.tflite.trace 1 +``` + +If this property has been set when TensorFlow Lite interpreter is initialized, +key events (e.g., operator invocation) from the interpreter will be traced. + +After you captured all the traces, disable tracing by setting the property value +to 0. + +```shell +adb shell setprop debug.tflite.trace 0 +``` + ### Android Studio CPU Profiler Capture traces with the diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 0765f00faf3..7743bc732fb 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -27,15 +27,12 @@ limitations under the License. #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/profiling/platform_profiler.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/shared_library.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" -#if defined(TFLITE_ENABLE_DEFAULT_PROFILER) -#include "tensorflow/lite/profiling/platform_profiler.h" -#endif - // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11. #if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L #if !defined(__ANDROID__) || __ANDROID_API__ >= 28 @@ -630,9 +627,7 @@ TfLiteStatus InterpreterBuilder::operator()( (*interpreter)->AddSubgraphs(subgraphs->size() - 1); } -#if defined(TFLITE_ENABLE_DEFAULT_PROFILER) - (*interpreter)->SetProfiler(tflite::profiling::CreatePlatformProfiler()); -#endif + (*interpreter)->SetProfiler(tflite::profiling::MaybeCreatePlatformProfiler()); for (int subgraph_index = 0; subgraph_index < subgraphs->size(); ++subgraph_index) { diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 59afc0c3608..ba0f569bc34 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -152,7 +152,7 @@ public final class Interpreter implements AutoCloseable { *
    *
  • Startup time and resize time may increase. *
  • Baseline memory consumption may increase. - *
  • Compatibility with other delegates (e.g., GPU) has not been fully validated. + *
  • May be ignored if another delegate (eg NNAPI) have been applied. *
  • Quantized models will not see any benefit. *
* diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 2d1844fbd39..fc0857fdf43 100644 --- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -367,8 +367,14 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( } tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate( xnnpack_create(&options), xnnpack_delete); - if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) != - kTfLiteOk) { + auto delegation_status = + interpreter->ModifyGraphWithDelegate(std::move(delegate)); + // kTfLiteApplicationError occurs in cases where delegation fails but + // the runtime is invokable (eg. another delegate has already been applied). + // We don't throw an Exception in that case. + // TODO(b/166483905): Add support for multiple delegates when model allows. + if (delegation_status != kTfLiteOk && + delegation_status != kTfLiteApplicationError) { ThrowException(env, kIllegalArgumentException, "Internal error: Failed to apply XNNPACK delegate: %s", error_reporter->CachedErrorMessage()); diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java index 45d66e24d35..fc9038c4de0 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java @@ -56,6 +56,25 @@ public final class NnApiDelegateTest { } } + @Test + public void testInterpreterWithNnApiAndXNNPack() throws Exception { + Interpreter.Options options = new Interpreter.Options(); + options.setUseXNNPACK(true); + + try (NnApiDelegate delegate = new NnApiDelegate(); + Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) { + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + } + @Test public void testInterpreterWithNnApiAllowFp16() throws Exception { Interpreter.Options options = new Interpreter.Options(); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 1190b2edc6a..b109e8ed78a 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -464,6 +464,7 @@ cc_library( "reference/integer_ops/fully_connected.h", "reference/integer_ops/l2normalization.h", "reference/integer_ops/logistic.h", + "reference/integer_ops/mean.h", "reference/integer_ops/mul.h", "reference/integer_ops/pooling.h", "reference/integer_ops/tanh.h", @@ -492,7 +493,6 @@ cc_library( "//conditions:default": [ "reference/integer_ops/dequantize.h", "reference/integer_ops/log_softmax.h", - "reference/integer_ops/mean.h", "reference/integer_ops/transpose_conv.h", "reference/reference_ops.h", "reference/string_comparisons.h", diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 66a2d977f39..92bb01a2900 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -263,6 +263,30 @@ inline void gen_lut(const std::function& func, double min, std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0); } +// generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in +// softmax +inline void gen_lut(const std::function& func, float min, + float max, int16_t* table, const int num) { + // size of table should equal to num + 1 + // last element only for slope calculation + float step = (max - min) / (num - 1); + float half_step = step / 2.0f; + for (int i = 0; i < num - 1; i++) { + float sample_val = TfLiteRound(func(min + i * step) * 32768.0f); + float midpoint_interp_val = + TfLiteRound((func(min + (i + 1) * step) * 32768.0f + + TfLiteRound(func(min + i * step) * 32768.0f)) / + 2.0f); + float midpoint_val = + TfLiteRound(func(min + i * step + half_step) * 32768.0f); + float midpoint_err = midpoint_interp_val - midpoint_val; + float bias = TfLiteRound(midpoint_err / 2.0f); + table[i] = std::min(std::max(sample_val - bias, -32768.0f), 32767.0f); + } + table[num - 1] = std::min( + std::max(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f); +} + // int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) { // 512 base value, lut[513] only for calculate slope diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h index f4bcb2bd06e..3e9cd0caa51 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h @@ -63,45 +63,47 @@ inline void ConvPerChannel( const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; int32_t acc = 0; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + + // Zero padding by omitting the areas outside the image. + const bool is_point_inside_image = + (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height); + + if (!is_point_inside_image) { + continue; + } + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // Zero padding by omitting the areas outside the image. - const bool is_point_inside_image = - (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height); - if (is_point_inside_image) { - int32_t input_val = input_data[Offset( - input_shape, batch, in_y, in_x, in_channel)]; - int32_t filter_val = - filter_data[Offset(filter_shape, out_channel, filter_y, - filter_x, in_channel)]; - // Accumulate with 32 bits accumulator. - // In the nudging process during model quantization, we force - // real value of 0.0 be represented by a quantized value. This - // guarantees that the input_offset is a int8_t, even though - // it is represented using int32_t. int32_t += int8_t * - // (int8_t - int8_t) so the highest value we can get from each - // accumulation is [-127, 127] * ([-128, 127] - - // [-128, 127]), which is [-32512, 32512]. log2(32512) - // = 14.98, which means we can accumulate at least 2^16 - // multiplications without overflow. The accumulator is - // applied to a filter so the accumulation logic will hold as - // long as the filter size (filter_y * filter_x * in_channel) - // does not exceed 2^16, which is the case in all the models - // we have seen so far. - // TODO(jianlijianli): Add a check to make sure the - // accumulator depth is smaller than 2^16. - acc += filter_val * (input_val + input_offset); - } + int32_t input_val = input_data[Offset(input_shape, batch, in_y, + in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( + filter_shape, out_channel, filter_y, filter_x, in_channel)]; + // Accumulate with 32 bits accumulator. + // In the nudging process during model quantization, we force + // real value of 0.0 be represented by a quantized value. This + // guarantees that the input_offset is a int8_t, even though + // it is represented using int32_t. int32_t += int8_t * + // (int8_t - int8_t) so the highest value we can get from each + // accumulation is [-127, 127] * ([-128, 127] - + // [-128, 127]), which is [-32512, 32512]. log2(32512) + // = 14.98, which means we can accumulate at least 2^16 + // multiplications without overflow. The accumulator is + // applied to a filter so the accumulation logic will hold as + // long as the filter size (filter_y * filter_x * in_channel) + // does not exceed 2^16, which is the case in all the models + // we have seen so far. + // TODO(jianlijianli): Add a check to make sure the + // accumulator depth is smaller than 2^16. + acc += filter_val * (input_val + input_offset); } } } @@ -164,35 +166,37 @@ inline void ConvPerChannel( const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; std::int64_t acc = 0; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + + // Zero padding by omitting the areas outside the image. + const bool is_point_inside_image = + (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height); + + if (!is_point_inside_image) { + continue; + } + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // Zero padding by omitting the areas outside the image. - const bool is_point_inside_image = - (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height); - if (is_point_inside_image) { - int32_t input_val = input_data[Offset( - input_shape, batch, in_y, in_x, in_channel)]; - int32_t filter_val = - filter_data[Offset(filter_shape, out_channel, filter_y, - filter_x, in_channel)]; - // Accumulate with 64 bits accumulator. - // int64_t += int8_t * int16_t so the highest value we can - // get from each accumulation is [-127, 127] * ([-32768, - // 32767] - - // [-32768, 32767]), which is [-8322945, 8322945]. - // log2(8322945) = 22.99. - acc += filter_val * input_val; - } + int32_t input_val = input_data[Offset(input_shape, batch, in_y, + in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( + filter_shape, out_channel, filter_y, filter_x, in_channel)]; + // Accumulate with 64 bits accumulator. + // int64_t += int8_t * int16_t so the highest value we can + // get from each accumulation is [-127, 127] * ([-32768, + // 32767] - + // [-32768, 32767]), which is [-8322945, 8322945]. + // log2(8322945) = 22.99. + acc += filter_val * input_val; } } } diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h index 1e29f8c61a7..bd484270012 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h @@ -23,9 +23,9 @@ namespace reference_integer_ops { template inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier, int32_t shift, const RuntimeShape& unextended_input_shape, - const integer_type* input_data, int32 input_zero_point, + const integer_type* input_data, int32_t input_zero_point, const RuntimeShape& unextended_output_shape, - integer_type* output_data, int32 output_zero_point) { + integer_type* output_data, int32_t output_zero_point) { // Current implementation only supports dimension equals 4 and simultaneous // reduction over width and height. TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4); @@ -53,7 +53,7 @@ inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier, for (int out_b = 0; out_b < output_batch; ++out_b) { for (int out_d = 0; out_d < output_depth; ++out_d) { - int32 acc = 0; + int32_t acc = 0; for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] - diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index 597d015d0b1..7953b4347c6 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -186,11 +186,11 @@ inline bool Mean(const T* input_data, const int* input_dims, } // Calculate mean by dividing output_data by num of aggregated element. - U num_elements_in_axis = 1; + size_t num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); // Overflow prevention. - if (current > (std::numeric_limits::max() / num_elements_in_axis)) { + if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; @@ -359,11 +359,11 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point, } // Calculate mean by dividing output_data by num of aggregated element. - U num_elements_in_axis = 1; + size_t num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); // Overflow prevention. - if (current > (std::numeric_limits::max() / num_elements_in_axis)) { + if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index 9db742ddf03..0164f82f19e 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -1044,7 +1044,9 @@ struct SoftmaxParams { int32_t zero_point; float scale; float* table; + // int16 LUT for exp(x), where x uniform distributed between [-10.0 , 0.0] int16_t* exp_lut; + // int16 LUT for 1 / (1 + x), where x uniform distributed between [0.0 , 1.0] int16_t* one_over_one_plus_x_lut; uint8_t* uint8_table1; uint8_t* uint8_table2; diff --git a/tensorflow/lite/micro/CONTRIBUTING.md b/tensorflow/lite/micro/CONTRIBUTING.md index 9d6fc83463b..f5a974b60af 100644 --- a/tensorflow/lite/micro/CONTRIBUTING.md +++ b/tensorflow/lite/micro/CONTRIBUTING.md @@ -72,6 +72,12 @@ We strongly recommend that contributors: * We will be adding internal checks that automate this requirement by matching the PR description to the regexp: `(Fixes|Issue) #` +1. Unit tests are critical to a healthy codebase. PRs without tests should be + the exception rather than the norm. And contributions to improve, simplify, + or make the unit tests more exhaustive are welcome! Please refer to + [this guideline](https://google.github.io/eng-practices/review/developer/small-cls.html#test_code) + on how test code and writing small PRs should be reconciled. + ## Guidelines for Specific Contribution Categories We provide some additional guidelines for different categories of contributions. @@ -86,6 +92,9 @@ fixing a bug needs a bigger architectural change. [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) to determine the scope of the bug fix. 1. Send a PR (if that is determined to be the best path forward). +1. Bugfix PRs should be accompanied by a test case that fails prior to the fix + and passes with the fix. This validates that the fix works as expected, and + helps prevent future regressions. ### Reference Kernel Implementations diff --git a/tensorflow/lite/micro/benchmarks/Makefile.inc b/tensorflow/lite/micro/benchmarks/Makefile.inc index 4a57ef39d69..d9dfba265ed 100644 --- a/tensorflow/lite/micro/benchmarks/Makefile.inc +++ b/tensorflow/lite/micro/benchmarks/Makefile.inc @@ -1,7 +1,3 @@ -$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) -$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,)) - - KEYWORD_BENCHMARK_SRCS := \ tensorflow/lite/micro/benchmarks/keyword_benchmark.cc \ tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc diff --git a/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc b/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc index 2e727095a5c..87f2cdff104 100644 --- a/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc +++ b/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc @@ -55,7 +55,7 @@ void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value, // The blue LED is lit for all negative values am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_BLUE); // The red LED is lit in only some cases - if (y_value <= -0.75) { + if (y_value <= -0.75f) { am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_RED); } else { am_devices_led_off(am_bsp_psLEDs, AM_BSP_LED_RED); @@ -68,13 +68,14 @@ void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value, // The green LED is lit for all positive values am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_GREEN); // The yellow LED is lit in only some cases - if (y_value >= 0.75) { + if (y_value >= 0.75f) { am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_YELLOW); } else { am_devices_led_off(am_bsp_psLEDs, AM_BSP_LED_YELLOW); } } // Log the current X and Y values - TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n", x_value, - y_value); + TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n", + static_cast(x_value), + static_cast(y_value)); } diff --git a/tensorflow/lite/micro/examples/person_detection/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/Makefile.inc index a295bb83f71..304dd95d874 100644 --- a/tensorflow/lite/micro/examples/person_detection/Makefile.inc +++ b/tensorflow/lite/micro/examples/person_detection/Makefile.inc @@ -1,6 +1,3 @@ -$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) -$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) - person_detection_MODEL_SRCS := \ tensorflow/lite/micro/examples/person_detection/model_settings.cc \ $(MAKEFILE_DIR)/downloads/person_model_grayscale/person_detect_model_data.cc diff --git a/tensorflow/lite/micro/kernels/add.cc b/tensorflow/lite/micro/kernels/add.cc index 79a04875def..7c63eeaba98 100644 --- a/tensorflow/lite/micro/kernels/add.cc +++ b/tensorflow/lite/micro/kernels/add.cc @@ -110,19 +110,22 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, tflite::ArithmeticParams op_params; SetActivationParams(data->output_activation_min_f32, data->output_activation_max_f32, &op_params); -#define TF_LITE_ADD(opname) \ - reference_ops::opname(op_params, tflite::micro::GetTensorShape(input1), \ - tflite::micro::GetTensorData(input1), \ - tflite::micro::GetTensorShape(input2), \ - tflite::micro::GetTensorData(input2), \ - tflite::micro::GetTensorShape(output), \ - tflite::micro::GetTensorData(output)) if (data->requires_broadcast) { - TF_LITE_ADD(BroadcastAdd4DSlow); + reference_ops::BroadcastAdd4DSlow( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } else { - TF_LITE_ADD(Add); + reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } -#undef TF_LITE_ADD } TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, @@ -147,27 +150,42 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, bool need_broadcast = reference_ops::ProcessBroadcastShapes( tflite::micro::GetTensorShape(input1), tflite::micro::GetTensorShape(input2), &op_params); -#define TF_LITE_ADD(type, opname, dtype) \ - type::opname(op_params, tflite::micro::GetTensorShape(input1), \ - tflite::micro::GetTensorData(input1), \ - tflite::micro::GetTensorShape(input2), \ - tflite::micro::GetTensorData(input2), \ - tflite::micro::GetTensorShape(output), \ - tflite::micro::GetTensorData(output)); if (output->type == kTfLiteInt8) { if (need_broadcast) { - TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t); + reference_integer_ops::BroadcastAdd4DSlow( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } else { - TF_LITE_ADD(reference_integer_ops, Add, int8_t); + reference_integer_ops::Add( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } } else { if (need_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t); + reference_ops::BroadcastAdd4DSlow( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } else { - TF_LITE_ADD(reference_ops, Add, uint8_t); + reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } } -#undef TF_LITE_ADD } return kTfLiteOk; diff --git a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc index ddb144406bb..afdc564d808 100644 --- a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc +++ b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/tools/make/downloads/flatbuffers/include/flatbuffers/flexbuffers.h" namespace tflite { @@ -26,30 +27,51 @@ namespace ethosu { constexpr uint8_t CO_TYPE_ETHOSU = 1; +struct OpData { + int cms_data_size; + int buffer_idx; +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(context != nullptr); TF_LITE_ENSURE(context, node->inputs->size > 0); - TF_LITE_ENSURE(context, context->tensors); + TFLITE_DCHECK(node->user_data != nullptr); TF_LITE_ENSURE(context, node->custom_initial_data_size > 0); + + OpData* data = static_cast(node->user_data); + int num_base_addr = node->inputs->size + node->outputs->size; + TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena( + context, num_base_addr * sizeof(uint64_t), &data->buffer_idx)); + + // Get command stream data size + TfLiteTensor* tensor = context->GetTensor(context, node->inputs->data[0]); + data->cms_data_size = tensor->bytes; + return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(context != nullptr); + TFLITE_DCHECK(context->GetScratchBuffer != nullptr); + // Get base addresses - TfLiteTensor* tensor; - int num_base_addr = node->inputs->size + node->outputs->size; + TfLiteEvalTensor* tensor; int i = 0; int num_tensors = 0; - uint64_t base_addrs[num_base_addr]; void* cms_data; - int cms_data_size; uint8_t co_type; int result; + const OpData* data = static_cast(node->user_data); + uint64_t* base_addrs = static_cast( + context->GetScratchBuffer(context, data->buffer_idx)); const uint8_t* custom_data = static_cast(node->custom_initial_data); @@ -60,26 +82,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } - // Get command stream data address and size - tensor = &(context->tensors[node->inputs->data[0]]); + // Get command stream data address + tensor = context->GetEvalTensor(context, node->inputs->data[0]); cms_data = reinterpret_cast(tensor->data.uint8); - cms_data_size = tensor->bytes; // Get adresses to weights/scratch/input data for (i = 1; i < node->inputs->size; ++i) { - tensor = &(context->tensors[node->inputs->data[i]]); + tensor = context->GetEvalTensor(context, node->inputs->data[i]); base_addrs[num_tensors] = reinterpret_cast(tensor->data.uint8); num_tensors++; } // Get adresses to output data for (i = 0; i < node->outputs->size; ++i) { - tensor = &(context->tensors[node->outputs->data[i]]); + tensor = context->GetEvalTensor(context, node->outputs->data[i]); base_addrs[num_tensors] = reinterpret_cast(tensor->data.uint8); num_tensors++; } - result = ethosu_invoke(cms_data, cms_data_size, base_addrs, num_tensors); + // Ethos-U guarantees that the tensors that require a base pointer are among + // the 8 first tensors + num_tensors = std::min(num_tensors, 8); + + result = + ethosu_invoke(cms_data, data->cms_data_size, base_addrs, num_tensors); if (-1 == result) { return kTfLiteError; } else { @@ -89,8 +115,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace ethosu -TfLiteRegistration Register_ETHOSU() { - return {ethosu::Init, ethosu::Free, ethosu::Prepare, ethosu::Eval}; +TfLiteRegistration* Register_ETHOSU() { + static TfLiteRegistration r = {ethosu::Init, + ethosu::Free, + ethosu::Prepare, + ethosu::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; } const char* GetString_ETHOSU() { return "ethos-u"; } diff --git a/tensorflow/lite/micro/kernels/reduce.cc b/tensorflow/lite/micro/kernels/reduce.cc index 5cae782482e..2f9e6398591 100644 --- a/tensorflow/lite/micro/kernels/reduce.cc +++ b/tensorflow/lite/micro/kernels/reduce.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -32,6 +33,20 @@ namespace reduce { constexpr int kMaxNumberOfAxis = 4; constexpr int kMaxNumberOfReducedAxis = 2; +struct OpData { + int32_t multiplier; + int shift; + int temp_buffer_idx; + int input_zp; + float input_scale; + int output_zp; + float output_scale; +}; + +void* InitMean(TfLiteContext* context, const char* buffer, size_t length) { + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} + TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { // Inputs Tensor (dtype depends on quantization): // [0] = Input @@ -51,6 +66,25 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, 0); + OpData* op_data = reinterpret_cast(node->user_data); + const TfLiteTensor* output = GetOutput(context, node, 0); + if (input->type == kTfLiteInt8) { + const double real_multiplier = static_cast(input->params.scale) / + static_cast(output->params.scale); + QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift); + } + + int output_size = NumElements(output); + if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) { + context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t), + &op_data->temp_buffer_idx); + op_data->input_zp = input->params.zero_point; + op_data->input_scale = input->params.scale; + op_data->output_zp = output->params.zero_point; + op_data->output_scale = output->params.scale; + } + TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018) return kTfLiteOk; @@ -74,26 +108,25 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); TfLiteReducerParams* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); int num_axis = static_cast(ElementCount(*axis->dims)); int temp_index[kMaxNumberOfAxis]; int resolved_axis[kMaxNumberOfReducedAxis]; + tflite::MeanParams op_params; + ResolveAxis(tflite::micro::GetTensorData(axis), num_axis, &op_params); + // TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until + // scratch tensor allocation has been implemented in (b/132070898) + bool is_valid_inputs = (input->dims->size == 4 && op_params.axis_count == 2 && + ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1))); + TF_LITE_ENSURE_MSG( + context, is_valid_inputs == true, + "Number of Input " + "dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]"); switch (input->type) { case kTfLiteFloat32: { - tflite::MeanParams op_params; - ResolveAxis(tflite::micro::GetTensorData(axis), num_axis, - &op_params); - // TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until - // scratch tensor allocation has been implemented in (b/132070898) - bool is_valid_inputs = - (input->dims->size == 4 && op_params.axis_count == 2 && - ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || - (op_params.axis[0] == 2 && op_params.axis[1] == 1))); - TF_LITE_ENSURE_MSG( - context, is_valid_inputs == true, - "Number of Input " - "dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]"); // TODO(b/139102329): Handle the below special case in the combined // reference method. // Defer to specialized implementation for 4D Mean across axes 1 & 2. @@ -114,10 +147,81 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(output))); } } break; + case kTfLiteInt8: { + if (params->keep_dims) { + reference_integer_ops::Mean( + op_params, op_data->multiplier, op_data->shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), op_data->input_zp, + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), op_data->output_zp); + } else if (op_data->input_zp == op_data->output_zp && + op_data->input_scale == op_data->output_scale) { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::Mean( + tflite::micro::GetTensorData(input), input->dims->data, + input->dims->size, tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), num_axis, + params->keep_dims, temp_index, resolved_axis, temp_buffer)); + } else { + int32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum( + tflite::micro::GetTensorData(input), op_data->input_zp, + op_data->input_scale, input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + op_data->output_zp, op_data->output_scale, output->dims->data, + output->dims->size, tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, resolved_axis, + temp_buffer, false)); + } + } break; + case kTfLiteUInt8: { + if (params->keep_dims) { + reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + op_data->input_zp, op_data->input_scale, + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), + op_data->output_zp, op_data->output_scale); + } else if (op_data->input_zp == op_data->output_zp && + op_data->input_scale == op_data->output_scale) { + uint32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::Mean(tflite::micro::GetTensorData(input), + input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + output->dims->data, output->dims->size, + tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, + resolved_axis, temp_buffer)); + } else { + uint32_t* temp_buffer = static_cast( + context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum( + tflite::micro::GetTensorData(input), op_data->input_zp, + op_data->input_scale, input->dims->data, input->dims->size, + tflite::micro::GetTensorData(output), + op_data->output_zp, op_data->output_scale, output->dims->data, + output->dims->size, tflite::micro::GetTensorData(axis), + num_axis, params->keep_dims, temp_index, resolved_axis, + temp_buffer, false)); + } + } break; default: // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018) TF_LITE_ENSURE_MSG(context, false, - "Currently, only float32 input type " + "Currently, only float32, int8 or uint8 input type " "is supported."); } return kTfLiteOk; @@ -125,7 +229,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } // namespace reduce TfLiteRegistration Register_MEAN() { - return {/*init=*/nullptr, + return {/*init=*/reduce::InitMean, /*free=*/nullptr, /*prepare=*/reduce::PrepareMeanOrSum, /*invoke=*/reduce::EvalMean, diff --git a/tensorflow/lite/micro/kernels/reduce_test.cc b/tensorflow/lite/micro/kernels/reduce_test.cc index 1e3ded2bd77..3207063b46b 100644 --- a/tensorflow/lite/micro/kernels/reduce_test.cc +++ b/tensorflow/lite/micro/kernels/reduce_test.cc @@ -25,7 +25,7 @@ namespace testing { namespace { // Common inputs and outputs. -// static const int kInputElements4D = 24; +static const int kInputElements4D = 24; static const int kInputShape4D[] = {4, 2, 2, 3, 2}; static const float kInputData4D[] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, @@ -90,6 +90,44 @@ void TestMeanFloatInput4D(const int* input_dims_data, const float* input_data, output_data, output_dims_count, params, tolerance)); } +template +void TestMeanOpQuantized(const int* input_dims_data, const float* input_data, + T* input_data_quant, float input_scale, + int input_zero_point, const int* axis_dims_data, + const int32_t* axis_data, const int* output_dims_data, + const float* expected_output_data, + T* output_data_quant, T* expected_output_data_quant, + float output_scale, int output_zero_point, + TfLiteReducerParams* params) { + // Convert dimesion arguments to TfLiteArrays + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + // Get number of elements in input and output tensors + const int output_dims_count = ElementCount(*output_dims); + + // Initialize tensors + constexpr int tensors_size = 3; + TfLiteTensor tensors[] = { + CreateQuantizedTensor(input_data, input_data_quant, input_dims, + input_scale, input_zero_point), + CreateInt32Tensor(axis_data, axis_dims), + CreateQuantizedTensor(output_data_quant, output_dims, output_scale, + output_zero_point), + }; + + // Quantize expected output + tflite::AsymmetricQuantize(expected_output_data, expected_output_data_quant, + output_dims_count, output_scale, + output_zero_point); + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + ValidateReduceGoldens(tensors, tensors_size, expected_output_data_quant, + output_data_quant, output_dims_count, params, 1.0)); +} + } // namespace } // namespace testing } // namespace tflite @@ -110,10 +148,55 @@ TF_LITE_MICRO_TEST(MeanFloat4DKeepDims) { ¶ms); } +TF_LITE_MICRO_TEST(MeanInt84DKeepDims) { + int8_t expected_output_data_quant[tflite::testing::kOutputElements]; + int8_t output_data_quant[tflite::testing::kOutputElements]; + int8_t input_data_quant[tflite::testing::kInputElements4D]; + + float input_scale = 0.5f; + int input_zero_point = 0; + float output_scale = 0.5f; + int output_zero_point = 0; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape4D, tflite::testing::kInputData4D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape, tflite::testing::kAxisData, + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) { + uint8_t expected_output_data_quant[tflite::testing::kOutputElements]; + uint8_t output_data_quant[tflite::testing::kOutputElements]; + uint8_t input_data_quant[tflite::testing::kInputElements4D]; + + float input_scale = 0.5f; + int input_zero_point = 128; + float output_scale = 0.5f; + int output_zero_point = 128; + + TfLiteReducerParams params = { + true // keep_dims + }; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape4D, tflite::testing::kInputData4D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape, tflite::testing::kAxisData, + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + output_data_quant, expected_output_data_quant, output_scale, + output_zero_point, ¶ms); +} + TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDims) { const int kOutputShape[] = {2, 2, 2}; float output_data[tflite::testing::kOutputElements]; - TfLiteReducerParams params = { false // keep_dims }; @@ -124,6 +207,50 @@ TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDims) { tflite::testing::kGoldenData, output_data, ¶ms); } +TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) { + int8_t expected_output_data_quant[tflite::testing::kOutputElements]; + int8_t output_data_quant[tflite::testing::kOutputElements]; + int8_t input_data_quant[tflite::testing::kInputElements4D]; + + const int kOutputShape[] = {2, 2, 2}; + TfLiteReducerParams params = { + false // keep_dims + }; + float input_scale = 0.5f; + int input_zero_point = 0; + float output_scale = 0.5f; + int output_zero_point = 0; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape4D, tflite::testing::kInputData4D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape, + tflite::testing::kGoldenData, output_data_quant, + expected_output_data_quant, output_scale, output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) { + uint8_t expected_output_data_quant[tflite::testing::kOutputElements]; + uint8_t output_data_quant[tflite::testing::kOutputElements]; + uint8_t input_data_quant[tflite::testing::kInputElements4D]; + + const int kOutputShape[] = {2, 2, 2}; + TfLiteReducerParams params = { + false // keep_dims + }; + float input_scale = 0.5f; + int input_zero_point = 128; + float output_scale = 0.5f; + int output_zero_point = 128; + + tflite::testing::TestMeanOpQuantized( + tflite::testing::kInputShape4D, tflite::testing::kInputData4D, + input_data_quant, input_scale, input_zero_point, + tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape, + tflite::testing::kGoldenData, output_data_quant, + expected_output_data_quant, output_scale, output_zero_point, ¶ms); +} + TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) { const int kInputShape4D[] = {4, 2, 2, 3, 1}; const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, @@ -132,7 +259,6 @@ TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) { const int kOutputShape[] = {2, 2, 1}; const float kGoldenData[] = {11.166667, 19.833334}; float output_data[kOutputElements]; - TfLiteReducerParams params = { false // keep_dims }; @@ -143,4 +269,54 @@ TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) { ¶ms); } +TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) { + const int kInputShape4D[] = {4, 2, 2, 3, 1}; + const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, + 11.0, 36.0, 14.0, 19.0, 17.0, 22.0}; + const int kOutputShape[] = {2, 2, 1}; + const float kGoldenData[] = {11.166667, 19.833334}; + TfLiteReducerParams params = { + false // keep_dims + }; + float input_scale = 0.5f; + int input_zero_point = 0; + float output_scale = 0.5f; + int output_zero_point = 0; + + int8_t output_data_quant[2]; + int8_t expected_output_data_quant[2]; + int8_t input_data_quant[12]; + + tflite::testing::TestMeanOpQuantized( + kInputShape4D, kInputData4D, input_data_quant, input_scale, + input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData, + kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant, + output_scale, output_zero_point, ¶ms); +} + +TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) { + const int kInputShape4D[] = {4, 2, 2, 3, 1}; + const float kInputData4D[] = {1.0, 24.0, 13.0, 3.0, 9.0, 17.0, + 11.0, 36.0, 14.0, 19.0, 17.0, 22.0}; + const int kOutputShape[] = {2, 2, 1}; + const float kGoldenData[] = {11.166667, 19.833334}; + TfLiteReducerParams params = { + false // keep_dims + }; + + float input_scale = 0.5f; + int input_zero_point = 128; + float output_scale = 0.5f; + int output_zero_point = 128; + + uint8_t output_data_quant[2]; + uint8_t expected_output_data_quant[2]; + uint8_t input_data_quant[12]; + + tflite::testing::TestMeanOpQuantized( + kInputShape4D, kInputData4D, input_data_quant, input_scale, + input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData, + kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant, + output_scale, output_zero_point, ¶ms); +} TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index e85c1a4a306..fa1b9caf077 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -30,23 +30,30 @@ namespace micro { namespace activations { namespace { +// Softmax parameter data that persists in user_data +static constexpr int kInt16LUTArraySize = 513; + TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, const TfLiteSoftmaxParams* params, SoftmaxParams* op_data) { - if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 || + input->type == kTfLiteInt16) { if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - } else { + } else if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768, + (0.001f * 1.f / 32768)); + } else { // input->type == kTfLiteInt8 TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); if (output->type == kTfLiteInt16) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768); - // NOTE: Current int16_t softmax output does not require symmetric - // scaling - // - so no need to verify scale here. - } else { + TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536, + (0.001f * 1.f / 65536)); + } else { // output->type == kTfLiteint8 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); @@ -55,15 +62,28 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, static const int kScaledDiffIntegerBits = 5; - int input_left_shift; - tflite::PreprocessSoftmaxScaling( - static_cast(params->beta), - static_cast(input->params.scale), kScaledDiffIntegerBits, - &op_data->input_multiplier, &input_left_shift); - op_data->input_left_shift = input_left_shift; - op_data->diff_min = - -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, - op_data->input_left_shift); + // Calculate input_multiplier and input_left_shift + if (input->type == kTfLiteInt16) { + int input_left_shift; + double input_scale_beta_rescale = + static_cast(input->params.scale) * + static_cast(params->beta) / + (10.0 / 65535.0); // scale the input_diff such that [-65535, 0] + // correspond to [-10.0, 0.0] + QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier, + &input_left_shift); + op_data->input_left_shift = input_left_shift; + } else { + int input_left_shift; + tflite::PreprocessSoftmaxScaling( + static_cast(params->beta), + static_cast(input->params.scale), kScaledDiffIntegerBits, + &op_data->input_multiplier, &input_left_shift); + op_data->input_left_shift = input_left_shift; + op_data->diff_min = + -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, + op_data->input_left_shift); + } } else { TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); @@ -91,7 +111,7 @@ void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output, tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); - } else { + } else if (input->type == kTfLiteInt8) { if (output->type == kTfLiteInt16) { tflite::reference_ops::Softmax( op_data, tflite::micro::GetTensorShape(input), @@ -105,6 +125,12 @@ void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output, tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } + } else { + tflite::reference_ops::SoftmaxInt16( + op_data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } } @@ -114,18 +140,50 @@ void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { } TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = static_cast(node->builtin_data); - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TF_LITE_ENSURE(context, NumDimensions(input) >= 1); - TfLiteTensor* output = GetOutput(context, node, 0); - TFLITE_DCHECK(node->user_data != nullptr); - SoftmaxParams* data = static_cast(node->user_data); - return CalculateSoftmaxParams(context, input, output, params, data); + TF_LITE_ENSURE(context, node->user_data != nullptr); + SoftmaxParams* op_data = static_cast(node->user_data); + // Only allocate LUTs for KTfLiteInt16 data type + if (input->type == kTfLiteInt16) { + void* raw_exp_lut = context->AllocatePersistentBuffer( + context, sizeof(int16_t) * kInt16LUTArraySize); + TF_LITE_ENSURE(context, raw_exp_lut != nullptr); + op_data->exp_lut = reinterpret_cast(raw_exp_lut); + void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer( + context, sizeof(int16_t) * kInt16LUTArraySize); + TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr); + op_data->one_over_one_plus_x_lut = + reinterpret_cast(one_over_one_plus_x_lut); + } + + if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE(context, input->type == kTfLiteInt8 || + input->type == kTfLiteUInt8 || + input->type == kTfLiteInt16); + } else { + TF_LITE_ENSURE_EQ(context, input->type, output->type); + } + + // Populate LUT if required + if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + // exp LUT only used on negative values + // we consider exp(-10.0) is insignificant to accumulation + gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f, + op_data->exp_lut, kInt16LUTArraySize); + gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, + op_data->one_over_one_plus_x_lut, kInt16LUTArraySize); + op_data->zero_point = output->params.zero_point; + op_data->scale = output->params.scale; + } + + auto* params = static_cast(node->builtin_data); + return CalculateSoftmaxParams(context, input, output, params, op_data); } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { @@ -133,16 +191,17 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); TFLITE_DCHECK(node->user_data != nullptr); - SoftmaxParams* data = static_cast(node->user_data); + SoftmaxParams op_data = *static_cast(node->user_data); switch (input->type) { case kTfLiteFloat32: { - SoftmaxFloat(input, output, *data); + SoftmaxFloat(input, output, op_data); return kTfLiteOk; } case kTfLiteInt8: - case kTfLiteUInt8: { - SoftmaxQuantized(input, output, *data); + case kTfLiteUInt8: + case kTfLiteInt16: { + SoftmaxQuantized(input, output, op_data); return kTfLiteOk; } default: diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc index 27828d2de34..808ea9396ba 100644 --- a/tensorflow/lite/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/micro/kernels/softmax_test.cc @@ -28,8 +28,13 @@ namespace { // quantization parameters. const float output_scale_int8 = 1.0f / 256.0f; const float output_scale_uint8 = 1.0f / 256.0f; +const float output_scale_int16 = 1.0f / 32768.0f; const int output_zero_point_int8 = -128; const int output_zero_point_uint8 = 0; +const int output_zero_point_int16 = 0; + +// Empirical tolerance in quantization space +const float tolerance_int16 = 7.0; // 1-dimensional test data. const int flat_size_1d = 5; @@ -291,7 +296,7 @@ void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data, int input_zero_point, const int* output_dims_data, const float* golden, T* golden_quantized, float output_scale, int output_zero_point, - T* output_data) { + T* output_data, float tolerance = 1.0) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); const int output_dims_count = ElementCount(*output_dims); @@ -310,7 +315,7 @@ void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data, output_zero_point); ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized, - output_dims_count, 1.0); + output_dims_count, tolerance); } } // namespace @@ -356,6 +361,21 @@ TF_LITE_MICRO_TEST(Softmax1DQuantizedInt8ShouldMatchGolden) { tflite::testing::output_zero_point_int8, output_data); } +TF_LITE_MICRO_TEST(Softmax1DQuantizedInt16ShouldMatchGolden) { + const float input_scale = 0.1f; + const int input_zero_point = 0; + + int16_t input_quantized[tflite::testing::flat_size_1d]; + int16_t golden_quantized[tflite::testing::flat_size_1d]; + int16_t output_data[tflite::testing::flat_size_1d]; + tflite::testing::TestSoftmaxQuantized( + tflite::testing::shape_1d, tflite::testing::input_data_1d, + input_quantized, input_scale, input_zero_point, tflite::testing::shape_1d, + tflite::testing::golden_1d, golden_quantized, + tflite::testing::output_scale_int16, + tflite::testing::output_zero_point_int16, output_data); +} + TF_LITE_MICRO_TEST(Softmax2DFloatShouldMatchGolden) { float output_data[tflite::testing::flat_size_2d]; tflite::testing::TestSoftmaxFloat( @@ -393,6 +413,21 @@ TF_LITE_MICRO_TEST(Softmax2DQuantizedInt8ShouldMatchGolden) { tflite::testing::output_zero_point_int8, output_data); } +TF_LITE_MICRO_TEST(Softmax2DQuantizedInt16ShouldMatchGolden) { + const float input_scale = 0.1f; + const int input_zero_point = 0; + + int16_t input_quantized[tflite::testing::flat_size_2d]; + int16_t golden_quantized[tflite::testing::flat_size_2d]; + int16_t output_data[tflite::testing::flat_size_2d]; + tflite::testing::TestSoftmaxQuantized( + tflite::testing::shape_2d, tflite::testing::input_data_2d, + input_quantized, input_scale, input_zero_point, tflite::testing::shape_2d, + tflite::testing::golden_2d, golden_quantized, + tflite::testing::output_scale_int16, + tflite::testing::output_zero_point_int16, output_data); +} + TF_LITE_MICRO_TEST(Softmax3DFloatShouldMatchGolden) { float output_data[tflite::testing::flat_size_3d]; tflite::testing::TestSoftmaxFloat( @@ -430,6 +465,22 @@ TF_LITE_MICRO_TEST(Softmax3DQuantizedInt8ShouldMatchGolden) { tflite::testing::output_zero_point_int8, output_data); } +TF_LITE_MICRO_TEST(Softmax3DQuantizedInt16ShouldMatchGolden) { + const float input_scale = 0.1f; + const int input_zero_point = 0; + + int16_t input_quantized[tflite::testing::flat_size_3d]; + int16_t golden_quantized[tflite::testing::flat_size_3d]; + int16_t output_data[tflite::testing::flat_size_3d]; + tflite::testing::TestSoftmaxQuantized( + tflite::testing::shape_3d, tflite::testing::input_data_3d, + input_quantized, input_scale, input_zero_point, tflite::testing::shape_3d, + tflite::testing::golden_3d, golden_quantized, + tflite::testing::output_scale_int16, + tflite::testing::output_zero_point_int16, output_data, + tflite::testing::tolerance_int16); +} + TF_LITE_MICRO_TEST(Softmax4DFloatShouldMatchGolden) { float output_data[tflite::testing::flat_size_4d]; tflite::testing::TestSoftmaxFloat( @@ -467,4 +518,19 @@ TF_LITE_MICRO_TEST(Softmax4DQuantizedInt8ShouldMatchGolden) { tflite::testing::output_zero_point_int8, output_data); } +TF_LITE_MICRO_TEST(Softmax4DQuantizedInt16ShouldMatchGolden) { + const float input_scale = 0.1f; + const int input_zero_point = 0; + + int16_t input_quantized[tflite::testing::flat_size_4d]; + int16_t golden_quantized[tflite::testing::flat_size_4d]; + int16_t output_data[tflite::testing::flat_size_4d]; + tflite::testing::TestSoftmaxQuantized( + tflite::testing::shape_4d, tflite::testing::input_data_4d, + input_quantized, input_scale, input_zero_point, tflite::testing::shape_4d, + tflite::testing::golden_4d, golden_quantized, + tflite::testing::output_scale_int16, + tflite::testing::output_zero_point_int16, output_data, + tflite::testing::tolerance_int16); +} TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/vexriscv/README.md b/tensorflow/lite/micro/kernels/vexriscv/README.md new file mode 100644 index 00000000000..bba47df4a6d --- /dev/null +++ b/tensorflow/lite/micro/kernels/vexriscv/README.md @@ -0,0 +1,39 @@ +# VexRISC-V + +## Maintainers + +* [danielyou0230](https://github.com/danielyou0230) +* [tal-x](https://github.com/tcal-x) + +## Background + +The optimized kernels for +[VexRISC-V](https://github.com/SpinalHDL/VexRiscv)/[Litex](https://github.com/enjoy-digital/litex) +are used to run Tensorflow Lite Micro in Zephyr on either + +* Digilent Arty board (e.g. Arty A7) +* [Renode](https://github.com/renode/renode): Open source simulation framework + (no hardware required) + +To run on Digilent Arty board (FPGA,) you'll also need a soft-CPU gateware for +the FPGA, please see +[Tensorflow lite demo running in Zephyr on Litex/VexRiscv SoC](https://github.com/antmicro/litex-vexriscv-tensorflow-lite-demo) +by Antmicro for more details. + +## Info + +To use VexRISC-V optimized kernels instead of reference kernel add +`TAGS=vexriscv` to the make command. The kernels that doesn't have optimization +for a certain micro architecture fallback to use TFLM reference kernels. + +# Example + +To compile the binary file with VexRISC-V optimizations, one can use the +following command + +``` +make -f tensorflow/lite/micro/tools/make/Makefile \ +TAGS=vexriscv \ +TARGET=zephyr_vexriscv \ +person_detection_int8_bin +``` diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 881b9b9abb0..f279450efb8 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -337,8 +337,8 @@ TfLiteStatus AllocationInfoBuilder::AddScratchBuffers( current->bytes = handle->bytes; current->first_created = handle->node_idx; current->last_used = handle->node_idx; - current->needs_allocating = true; current->offline_offset = kOnlinePlannedBuffer; + current->needs_allocating = true; } return kTfLiteOk; } @@ -655,6 +655,7 @@ TfLiteStatus MicroAllocator::StartModelAllocation( model_is_allocating_ = true; + TF_LITE_ENSURE_STATUS(InitScratchBufferHandles()); TF_LITE_ENSURE_STATUS(AllocateTfLiteEvalTensors(model, eval_tensors)); TF_LITE_ENSURE_STATUS( AllocateNodeAndRegistrations(model, node_and_registrations)); @@ -665,7 +666,8 @@ TfLiteStatus MicroAllocator::StartModelAllocation( } TfLiteStatus MicroAllocator::FinishModelAllocation( - const Model* model, TfLiteEvalTensor* eval_tensors) { + const Model* model, TfLiteEvalTensor* eval_tensors, + void** scratch_buffer_handles) { if (!model_is_allocating_) { TF_LITE_REPORT_ERROR(error_reporter_, "MicroAllocator: Model allocation finished before " @@ -676,9 +678,13 @@ TfLiteStatus MicroAllocator::FinishModelAllocation( const SubGraph* subgraph = GetSubGraphFromModel(model); TFLITE_DCHECK(subgraph != nullptr); + TF_LITE_ENSURE_STATUS(MoveScratchBufferHandlesToTail()); TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, eval_tensors)); TF_LITE_ENSURE_STATUS(AllocateVariables(subgraph, eval_tensors)); + if (scratch_buffer_handles != nullptr) { + *scratch_buffer_handles = scratch_buffer_handles_; + } model_is_allocating_ = false; return kTfLiteOk; } @@ -690,49 +696,39 @@ void* MicroAllocator::AllocatePersistentBuffer(size_t bytes) { TfLiteStatus MicroAllocator::RequestScratchBufferInArena(int node_id, size_t bytes, int* buffer_idx) { - // A consistency check to make sure scratch_buffer_handles_ is contiguous i.e. - // scratch_buffer_handles_ is pointing to the last allocation from memory - // allocator. - if (scratch_buffer_handles_ != nullptr && - reinterpret_cast(scratch_buffer_handles_) != - memory_allocator_->GetTail()) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Internal error: AllocateFromTail can not be called " - "between two RequestScratchBufferInArena calls."); - return kTfLiteError; + // This method is only called during Prepare stage, when the scratch buffer + // handles are placed in the head. + + // Allocate space for the new scratch buffer handle. + TF_LITE_ENSURE_STATUS(memory_allocator_->EnsureHeadSize( + sizeof(internal::ScratchBufferHandle) * (scratch_buffer_count_ + 1), + alignof(internal::ScratchBufferHandle))); + + if (scratch_buffer_handles_ == nullptr) { + // If this is the first scratch buffer handle, place it in the buffer head. + scratch_buffer_handles_ = reinterpret_cast( + memory_allocator_->GetBufferHead()); } + // Initialize the handle. `data` field will be set during memory planning. internal::ScratchBufferHandle* handle = - reinterpret_cast( - memory_allocator_->AllocateFromTail( - sizeof(internal::ScratchBufferHandle), - alignof(internal::ScratchBufferHandle))); - if (handle == nullptr) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Failed to register scratch buffer handle for node %s", - node_id); - return kTfLiteError; - } + scratch_buffer_handles_ + scratch_buffer_count_; *handle = {}; handle->bytes = bytes; handle->node_idx = node_id; + + // Buffer idx starts from 0 in this implementation. *buffer_idx = scratch_buffer_count_; scratch_buffer_count_ += 1; - // scratch_buffer_handles_ is in reverse order. The following code ensures - // that scratch_buffers[0] is pointing to the newly allocated handle. - scratch_buffer_handles_ = handle; return kTfLiteOk; } -void* MicroAllocator::GetScratchBuffer(int buffer_idx) const { - if (static_cast(buffer_idx) >= scratch_buffer_count_) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Buffer %d not found. %d buffers available.", - buffer_idx, scratch_buffer_count_); - return nullptr; - } - // scratch_buffer_handles_ is in reverse order. - return scratch_buffer_handles_[scratch_buffer_count_ - buffer_idx - 1].data; +void* MicroAllocator::GetScratchBuffer(void* scratch_buffer_handles, + int buffer_idx) { + internal::ScratchBufferHandle* handle = + reinterpret_cast(scratch_buffer_handles) + + buffer_idx; + return handle->data; } size_t MicroAllocator::used_bytes() const { @@ -1035,7 +1031,6 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan( builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets)); TF_LITE_ENSURE_STATUS( builder.AddTensors(subgraph, offline_planner_offsets, eval_tensors)); - TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_)); const AllocationInfo* allocation_info = builder.Finish(); @@ -1051,16 +1046,16 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan( size_t actual_available_arena_size = memory_allocator_->GetAvailableMemory(kBufferAlignment); + // Make sure we have enough arena size. if (planner.GetMaximumMemorySize() > actual_available_arena_size) { TF_LITE_REPORT_ERROR( error_reporter_, - "Arena size is too small for activation buffers. Needed %d but only " - "%d was available.", + "Arena size is too small for all buffers. Needed %u but only " + "%u was available.", planner.GetMaximumMemorySize(), actual_available_arena_size); return kTfLiteError; } - // Commit the plan. TF_LITE_ENSURE_STATUS(CommitPlan(error_reporter_, &planner, memory_allocator_->GetBufferHead(), @@ -1073,4 +1068,27 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan( return kTfLiteOk; } +TfLiteStatus MicroAllocator::InitScratchBufferHandles() { + scratch_buffer_count_ = 0; + scratch_buffer_handles_ = nullptr; + return kTfLiteOk; +} + +TfLiteStatus MicroAllocator::MoveScratchBufferHandlesToTail() { + if (scratch_buffer_count_ == 0) { + return kTfLiteOk; + } + auto src = scratch_buffer_handles_; + internal::ScratchBufferHandle* dest = + reinterpret_cast( + memory_allocator_->AllocateFromTail( + sizeof(internal::ScratchBufferHandle) * scratch_buffer_count_, + alignof(internal::ScratchBufferHandle))); + for (size_t i = 0; i < scratch_buffer_count_; i++) { + *(dest + i) = *(src + i); + } + scratch_buffer_handles_ = dest; + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h index efd11b8b230..5b478832d3d 100644 --- a/tensorflow/lite/micro/micro_allocator.h +++ b/tensorflow/lite/micro/micro_allocator.h @@ -123,9 +123,12 @@ class MicroAllocator { // the 'head' section of the memory arena. All variable tensor data will also // be allocated. This method should be called after assigning model resources // in StartModelAllocation(). The eval_tensors pointer should be the value - // passed into this class during StartModelAllocation(). + // passed into this class during StartModelAllocation(). Scratch buffer + // handles are stored in the out-param `scratch_buffer_handles`. This value + // will be used in `GetScratchBuffer` call to retrieve scratch buffers. TfLiteStatus FinishModelAllocation(const Model* model, - TfLiteEvalTensor* eval_tensors); + TfLiteEvalTensor* eval_tensors, + void** scratch_buffer_handles = nullptr); // Allocates a TfLiteTensor struct and populates the returned value with // properties from the model flatbuffer. This struct is allocated from @@ -160,12 +163,18 @@ class MicroAllocator { // This method only allocates a BufferHandle holding information for memory // planning. The buffer ptr is ready after `FinishModelAllocation` and can // be retrieved by `GetScratchBuffer` method using the returned buffer_idx. - // Note that there should be no tail allocation between two consecutive - // `RequestScratchBufferInArena` calls. + // Note that this method should only be called in the Prepare stage. TfLiteStatus RequestScratchBufferInArena(int node_id, size_t bytes, int* buffer_idx); - // Returns the pointer to the planned scratch buffer. - void* GetScratchBuffer(int buffer_idx) const; + + // Return the number of scratch buffers in the allocator. + size_t GetScratchBufferCount() const { return scratch_buffer_count_; } + + // Return the pointer to the planned scratch buffer. `scratch_buffer_handles` + // should be the corresponding value returned in `FinishModelAllocation`. + // `scratch_buffer_handles` is intentionally desigend as void*. The actual + // data type is an implementation detail, and is only visible in this class. + static void* GetScratchBuffer(void* scratch_buffer_handles, int buffer_idx); // Returns the arena usage in bytes, only available after // `FinishModelAllocation`. Otherwise, it will return 0. @@ -236,13 +245,16 @@ class MicroAllocator { ErrorReporter* error_reporter_; bool model_is_allocating_; - // In reverse order for efficiency. - // i.e. scratch_buffer_handles_[0] is the handle for the last buffer, - // corresponding to the last RequestScratchBufferInArena call. + // Points to the first allocated scratch buffer handle. + // Scratch buffer handles are placed in the head during `Prepare` stage and + // then moved to the tail for static memory plan. internal::ScratchBufferHandle* scratch_buffer_handles_ = nullptr; // How many scratch buffers have been allocated. size_t scratch_buffer_count_ = 0; + virtual TfLiteStatus InitScratchBufferHandles(); + virtual TfLiteStatus MoveScratchBufferHandlesToTail(); + TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 8c2f8e031d8..f4f591b5ad7 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -59,13 +59,31 @@ TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, size_t bytes, int* buffer_idx) { ContextHelper* helper = reinterpret_cast(ctx->impl_); - return helper->allocator_->RequestScratchBufferInArena( - helper->current_node_idx_, bytes, buffer_idx); + + // We can not forward the scratch buffer request to the allocator yet, + // otherwise the scratch buffer handles will ruin the data in `temp` section. + // These requests will be processed once the `temp` section is deallocated, + // i.e. after a node has been prepared. + + if (helper->scratch_buffer_count_ >= kMaxScratchBuffersPerOp) { + TF_LITE_REPORT_ERROR( + helper->error_reporter_, + "Node %d is allocating too many scratch buffers per op, max=%d", + helper->current_node_idx_, helper->scratch_buffer_count_); + } + helper->scrach_buffer_sizes_[helper->scratch_buffer_count_] = bytes; + // buffer_idx is 0 indexed. + *buffer_idx = helper->scratch_buffer_count_ + + helper->allocator_->GetScratchBufferCount(); + helper->scratch_buffer_count_++; + return kTfLiteOk; } void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { - return reinterpret_cast(ctx->impl_) - ->allocator_->GetScratchBuffer(buffer_idx); + ContextHelper* helper = reinterpret_cast(ctx->impl_); + + return helper->allocator_->GetScratchBuffer(helper->scratch_buffer_handles_, + buffer_idx); } void ContextHelper::ReportOpError(struct TfLiteContext* context, @@ -92,12 +110,39 @@ TfLiteEvalTensor* ContextHelper::GetEvalTensor( return &helper->eval_tensors_[tensor_idx]; } -void ContextHelper::SetNodeIndex(int idx) { current_node_idx_ = idx; } +void ContextHelper::SetNodeIndex(int idx) { + if (scratch_buffer_count_ != 0) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Internal error: Please commit scratch buffers " + "befrore moving to the next node"); + } + current_node_idx_ = idx; +} void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) { eval_tensors_ = eval_tensors; } +void ContextHelper::SetScratchBufferHandles(void* scratch_buffer_handle) { + scratch_buffer_handles_ = scratch_buffer_handle; +} + +TfLiteStatus ContextHelper::CommitScratchBuffers() { + size_t initial_buffer_count = allocator_->GetScratchBufferCount(); + for (size_t i = 0; i < scratch_buffer_count_; i++) { + int buffer_id; + allocator_->RequestScratchBufferInArena( + current_node_idx_, scrach_buffer_sizes_[i], &buffer_id); + if (static_cast(buffer_id) != initial_buffer_count + i) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Internal error. Scratch buffers are not contiguous.\n"); + } + } + scratch_buffer_count_ = 0; + return kTfLiteOk; +} + } // namespace internal MicroInterpreter::MicroInterpreter(const Model* model, @@ -297,6 +342,7 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { } } allocator_.ResetTempAllocations(); + context_helper_.CommitScratchBuffers(); } context_helper_.SetNodeIndex(-1); @@ -306,8 +352,12 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { context_.RequestScratchBufferInArena = nullptr; context_.GetScratchBuffer = context_helper_.GetScratchBuffer; + void* scratch_buffer_handles = nullptr; + TF_LITE_ENSURE_OK(&context_, - allocator_.FinishModelAllocation(model_, eval_tensors_)); + allocator_.FinishModelAllocation(model_, eval_tensors_, + &scratch_buffer_handles)); + context_helper_.SetScratchBufferHandles(scratch_buffer_handles); TF_LITE_ENSURE_STATUS(ResetVariableTensors()); tensors_allocated_ = true; diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 0983a007011..f36d9d80f96 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -32,6 +32,8 @@ namespace tflite { namespace internal { +constexpr size_t kMaxScratchBuffersPerOp = 8; + // A helper class to encapsulate the implementation of APIs in Context. // context->impl_ points to an instance of this class. // Check tensorflow/lite/c/common.h for detailed descriptions. @@ -53,19 +55,28 @@ class ContextHelper { int tensor_idx); static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context, int tensor_idx); + // Commits all scratch buffer allocations to MicroAllocator. + TfLiteStatus CommitScratchBuffers(); // Sets the current node index to assist with scratch buffer allocations: void SetNodeIndex(int idx); // Sets the pointer to a list of TfLiteEvalTensor instances. void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors); + // Sets the pointer to scratch buffer handle, which is needed by + // `GetScratchBuffer`. + void SetScratchBufferHandles(void* scratch_buffer_handle); private: - MicroAllocator* allocator_; - ErrorReporter* error_reporter_; - const Model* model_; - TfLiteEvalTensor* eval_tensors_; + MicroAllocator* allocator_ = nullptr; + ErrorReporter* error_reporter_ = nullptr; + const Model* model_ = nullptr; + TfLiteEvalTensor* eval_tensors_ = nullptr; + void* scratch_buffer_handles_ = nullptr; int current_node_idx_ = -1; + + size_t scrach_buffer_sizes_[kMaxScratchBuffersPerOp]; + size_t scratch_buffer_count_ = 0; }; } // namespace internal diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index 150dbead337..a4a4143a2ae 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -220,38 +220,45 @@ TF_LITE_MICRO_TEST(TestKernelMemoryPlanning) { tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); - constexpr size_t allocator_buffer_size = 2048; + constexpr size_t allocator_buffer_size = 4096; uint8_t allocator_buffer[allocator_buffer_size]; - tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, - allocator_buffer_size, - micro_test::reporter); - TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); - TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.inputs_size()); - TF_LITE_MICRO_EXPECT_EQ(static_cast(2), interpreter.outputs_size()); - TfLiteTensor* input = interpreter.input(0); - TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); - TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]); - input->data.uint8[0] = 2; - input->data.uint8[1] = 3; - input->data.uint8[2] = 1; + tflite::RecordingMicroAllocator* allocator = + tflite::RecordingMicroAllocator::Create( + allocator_buffer, allocator_buffer_size, micro_test::reporter); - uint8_t expected_median = 2; + // Make sure kernel memory planning works in multi-tenant context. + for (int i = 0; i < 3; i++) { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator, + micro_test::reporter); + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast(2), interpreter.outputs_size()); - { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); - TfLiteTensor* median = interpreter.output(0); - TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]); - TfLiteTensor* invoke_count = interpreter.output(1); - TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]); - } + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]); + input->data.uint8[0] = 2; + input->data.uint8[1] = 3; + input->data.uint8[2] = 1; - { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); - TfLiteTensor* median = interpreter.output(0); - TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]); - TfLiteTensor* invoke_count = interpreter.output(1); - TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]); + uint8_t expected_median = 2; + + { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + TfLiteTensor* median = interpreter.output(0); + TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]); + TfLiteTensor* invoke_count = interpreter.output(1); + TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]); + } + + { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + TfLiteTensor* median = interpreter.output(0); + TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]); + TfLiteTensor* invoke_count = interpreter.output(1); + TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]); + } } } diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 23c7ca96408..6a2d981dd34 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -593,13 +593,18 @@ TfLiteStatus SimpleStatefulOp::Prepare(TfLiteContext* context, TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena( context, sizeof(uint8_t) * NumElements(input->dims), &data->sorting_buffer)); + // We can interleave scratch / persistent buffer allocation. + data->invoke_count = reinterpret_cast( + context->AllocatePersistentBuffer(context, sizeof(int))); + *data->invoke_count = 0; + return kTfLiteOk; } TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); - data->invoke_count += 1; + *data->invoke_count += 1; const TfLiteTensor* input = GetInput(context, node, kInputTensor); const uint8_t* input_data = GetTensorData(input); @@ -626,7 +631,7 @@ TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context, int32_t* invoke_count_data = GetTensorData(invoke_count); median_data[0] = sorting_buffer[size / 2]; - invoke_count_data[0] = data->invoke_count; + invoke_count_data[0] = *data->invoke_count; return kTfLiteOk; } diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index a7897145d26..e41beb9fbb0 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -49,7 +49,7 @@ class SimpleStatefulOp { static constexpr int kMedianTensor = 0; static constexpr int kInvokeCount = 1; struct OpData { - int invoke_count = 0; + int* invoke_count = nullptr; int sorting_buffer = kBufferNotAllocated; }; diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 418da265f08..b896820d03c 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -106,7 +106,11 @@ MICROLITE_LIB_NAME := libtensorflow-microlite.a # These two must be defined before we include the target specific Makefile.inc # because we filter out the examples that are not supported for those targets. # See targets/xtensa_xpg_makefile.inc for an example. -MICRO_LITE_EXAMPLE_TESTS := $(shell find tensorflow/lite/micro/examples/ -name Makefile.inc) +# We limit max depth of directories to search to do not include +# target specific Makefiles that are included directly by the +# main example Makefile. +# See examples/micro_speech/Makefile.inc for an example. +MICRO_LITE_EXAMPLE_TESTS := $(shell find tensorflow/lite/micro/examples/ -maxdepth 2 -name Makefile.inc) MICRO_LITE_BENCHMARKS := $(wildcard tensorflow/lite/micro/benchmarks/Makefile.inc) MICROLITE_TEST_SRCS := \ @@ -176,6 +180,7 @@ tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h \ tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \ tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h \ tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h \ +tensorflow/lite/kernels/internal/reference/integer_ops/mean.h \ tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \ tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \ tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \ @@ -259,6 +264,8 @@ THIRD_PARTY_DOWNLOADS := $(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,)) $(eval $(call add_third_party_download,$(FLATBUFFERS_URL),$(FLATBUFFERS_MD5),flatbuffers,)) $(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) +$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) +$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,)) # These target-specific makefiles should modify or replace options like # CXXFLAGS or LIBS to work for a specific targeted architecture. All logic diff --git a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc index e46ca0717a4..3bbe6f9aeb9 100644 --- a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc @@ -1,10 +1,19 @@ # Settings for Hexagon toolchain. # REQUIRED: -# - Hexagon SDK 3.5 Toolkit (for hexagon-clang++, hexagon-sim). -# - HEXAGON_SDK_PREFIX environment variable must be set to location of +# - Hexagon SDK 3.5 Toolkit (for qurt, posix libs). +# HEXAGON_SDK_ROOT environment variable must be set to location of # Hexagon_SDK// on your machine. +# - Hexagon Tools root (for hexagon-clang++, hexagon-sim). +# The tool folder may be a part of the Hexagon SDK +# (e.g. $(HEXAGON_SDK_ROOT)/tools/HEXAGON_Tools) or installed +# separately. +# HEXAGON_ROOT environment variable must be set to location of +# HEXAGON_Tools on your machine. +# - HEXAGON_TOOL_VER: The Hexagon tool version (installed under HEXAGON_ROOT). +# For example: 8.3.07 # - HEXAGON_CPU_VER: The CPU version to use, will cause a compiler exception -# without providing a version. Acceptable values: v55-v67 +# without providing a version. Valid values may vary depending on tools +# version, but generally in the range: v55-v67 # # Unlike other targets, there is not currently a way to automatically download # the Hexagon SDK. For this reason, users are required to manually download @@ -12,8 +21,16 @@ ifeq ($(TARGET), hexagon) TARGET_ARCH := hexagon - ifndef HEXAGON_SDK_PREFIX - $(error HEXAGON_SDK_PREFIX is undefined) + ifndef HEXAGON_SDK_ROOT + $(error HEXAGON_SDK_ROOT is undefined) + endif + + ifndef HEXAGON_TOOL_VER + $(error HEXAGON_TOOL_VER is undefined) + endif + + ifndef HEXAGON_ROOT + $(error HEXAGON_ROOT is undefined) endif ifndef HEXAGON_CPU_VER @@ -55,6 +72,7 @@ ifeq ($(TARGET), hexagon) -mcpu=$(HEXAGON_CPU_VER) \ -m$(HEXAGON_CPU_VER) + export PATH := $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/bin:$(PATH) TARGET_TOOLCHAIN_PREFIX := hexagon- CXX_TOOL := clang++ CC_TOOL := clang @@ -63,11 +81,11 @@ ifeq ($(TARGET), hexagon) CCFLAGS += $(PLATFORM_ARGS) LDFLAGS += \ -Wl,--gc-sections -lhexagon \ - $(HEXAGON_SDK_PREFIX)/tools/HEXAGON_Tools/8.3.07/Tools/target/hexagon/lib/v66/libstdc++.a + $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/target/hexagon/lib/v66/libstdc++.a INCLUDES += \ - -I$(HEXAGON_SDK_PREFIX)/libs/common/qurt/computev66/include/posix \ - -I$(HEXAGON_SDK_PREFIX)/libs/common/qurt/computev66/include/qurt + -I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/posix \ + -I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/qurt TEST_SCRIPT := tensorflow/lite/micro/testing/test_hexagon_binary.sh endif diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD index ac957590c21..b54e742e4b5 100644 --- a/tensorflow/lite/profiling/BUILD +++ b/tensorflow/lite/profiling/BUILD @@ -1,13 +1,14 @@ +load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined") package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) -common_copts = [ - "-Wall", -] + tflite_copts() +common_copts = tflite_copts() + if_not_windows(["-Wall"]) cc_library( name = "profiler", @@ -23,6 +24,16 @@ cc_library( ], ) +cc_test( + name = "profiler_test", + srcs = ["profiler_test.cc"], + deps = [ + ":profiler", + ":test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "atrace_profiler", srcs = ["atrace_profiler.cc"], @@ -35,10 +46,21 @@ cc_library( ], ) +cc_test( + name = "atrace_profiler_test", + srcs = ["atrace_profiler_test.cc"], + deps = [ + ":atrace_profiler", + ":test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "platform_profiler", srcs = ["platform_profiler.cc"], hdrs = ["platform_profiler.h"], + compatible_with = get_compatible_with_portable(), copts = common_copts, deps = [ "//tensorflow/lite/core/api", @@ -48,16 +70,6 @@ cc_library( }), ) -cc_test( - name = "profiler_test", - srcs = ["profiler_test.cc"], - deps = [ - ":profiler", - "//tensorflow/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) - cc_library( name = "profile_buffer", hdrs = ["profile_buffer.h"], @@ -69,6 +81,16 @@ cc_library( ], ) +cc_test( + name = "profile_buffer_test", + srcs = ["profile_buffer_test.cc"], + deps = [ + ":profile_buffer", + ":test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "time", srcs = ["time.cc"], @@ -76,6 +98,16 @@ cc_library( copts = common_copts, ) +cc_test( + name = "time_test", + srcs = ["time_test.cc"], + deps = [ + ":test_main", + ":time", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "memory_info", srcs = ["memory_info.cc"], @@ -83,31 +115,21 @@ cc_library( copts = common_copts, ) -cc_test( - name = "time_test", - srcs = ["time_test.cc"], - copts = common_copts, - deps = [ - ":time", - "//tensorflow/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) - cc_test( name = "memory_info_test", srcs = ["memory_info_test.cc"], - copts = common_copts, tags = [ # Some low-level checks, like heap size check, may break in asan, msan # and tsan. So, disable such tests. "noasan", "nomsan", "notsan", + # TODO(b/166227284): Fix the test for Android. + "tflite_not_portable_android", ], deps = [ ":memory_info", - "//tensorflow/lite/testing:util", + ":test_main", "@com_google_googletest//:gtest", ], ) @@ -125,10 +147,9 @@ cc_library( cc_test( name = "profile_summary_formatter_test", srcs = ["profile_summary_formatter_test.cc"], - copts = common_copts, deps = [ ":profile_summary_formatter", - "//tensorflow/lite/testing:util", + ":test_main", "@com_google_googletest//:gtest", ], ) @@ -151,26 +172,28 @@ cc_library( cc_test( name = "profile_summarizer_test", srcs = ["profile_summarizer_test.cc"], - copts = common_copts, deps = [ ":profile_summarizer", ":profiler", + ":test_main", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:subgraph_test_util", "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_main", + testonly = 1, + srcs = ["test_main.cc"], + visibility = ["//visibility:private"], + deps = [ "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) -cc_test( - name = "profile_buffer_test", - srcs = ["profile_buffer_test.cc"], - deps = [ - ":profile_buffer", - "//tensorflow/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) +tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/profiling/atrace_profiler.cc b/tensorflow/lite/profiling/atrace_profiler.cc index 4bdaf9d9e06..cc29c2df34a 100644 --- a/tensorflow/lite/profiling/atrace_profiler.cc +++ b/tensorflow/lite/profiling/atrace_profiler.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/profiling/atrace_profiler.h" #include +#if defined(__ANDROID__) +#include +#endif #include @@ -89,8 +92,16 @@ class ATraceProfiler : public tflite::Profiler { FpEndSection atrace_end_section_; }; -std::unique_ptr CreateATraceProfiler() { - return std::unique_ptr(new ATraceProfiler()); +std::unique_ptr MaybeCreateATraceProfiler() { +#if defined(__ANDROID__) + constexpr char kTraceProp[] = "debug.tflite.trace"; + char trace_enabled[PROP_VALUE_MAX] = ""; + int length = __system_property_get(kTraceProp, trace_enabled); + if (length == 1 && trace_enabled[0] == '1') { + return std::unique_ptr(new ATraceProfiler()); + } +#endif // __ANDROID__ + return nullptr; } } // namespace profiling diff --git a/tensorflow/lite/profiling/atrace_profiler.h b/tensorflow/lite/profiling/atrace_profiler.h index d103cbc8536..044db1cd6cc 100644 --- a/tensorflow/lite/profiling/atrace_profiler.h +++ b/tensorflow/lite/profiling/atrace_profiler.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { namespace profiling { -std::unique_ptr CreateATraceProfiler(); +std::unique_ptr MaybeCreateATraceProfiler(); } // namespace profiling } // namespace tflite diff --git a/tensorflow/lite/profiling/atrace_profiler_test.cc b/tensorflow/lite/profiling/atrace_profiler_test.cc new file mode 100644 index 00000000000..d2a5c5264b5 --- /dev/null +++ b/tensorflow/lite/profiling/atrace_profiler_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/profiling/atrace_profiler.h" + +#if defined(__ANDROID__) +#include +#endif + +#include +#include + +namespace tflite { +namespace profiling { + +namespace { + +TEST(ATraceProfilerTest, MaybeCreateATraceProfiler) { + auto default_profiler = MaybeCreateATraceProfiler(); + EXPECT_EQ(nullptr, default_profiler.get()); + +#if defined(__ANDROID__) + if (__system_property_set("debug.tflite.trace", "1") == 0) { + auto profiler = MaybeCreateATraceProfiler(); + EXPECT_NE(nullptr, profiler.get()); + } + + if (__system_property_set("debug.tflite.trace", "0") == 0) { + auto no_profiler = MaybeCreateATraceProfiler(); + EXPECT_EQ(nullptr, no_profiler.get()); + } +#endif // __ANDROID__ +} + +} // namespace +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/lite/profiling/memory_info_test.cc b/tensorflow/lite/profiling/memory_info_test.cc index a6bd2e4a667..9b580b75adf 100644 --- a/tensorflow/lite/profiling/memory_info_test.cc +++ b/tensorflow/lite/profiling/memory_info_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/profiling/memory_info.h" #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { @@ -71,9 +70,3 @@ TEST(MemoryUsage, IsSupported) { } // namespace memory } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/profiling/platform_profiler.cc b/tensorflow/lite/profiling/platform_profiler.cc index cd0770c2348..6ee290cb982 100644 --- a/tensorflow/lite/profiling/platform_profiler.cc +++ b/tensorflow/lite/profiling/platform_profiler.cc @@ -25,11 +25,11 @@ limitations under the License. namespace tflite { namespace profiling { -std::unique_ptr CreatePlatformProfiler() { +std::unique_ptr MaybeCreatePlatformProfiler() { #if defined(__ANDROID__) - return CreateATraceProfiler(); + return MaybeCreateATraceProfiler(); #else - return std::unique_ptr(nullptr); + return nullptr; #endif } diff --git a/tensorflow/lite/profiling/platform_profiler.h b/tensorflow/lite/profiling/platform_profiler.h index 87361b30b50..52a51f87634 100644 --- a/tensorflow/lite/profiling/platform_profiler.h +++ b/tensorflow/lite/profiling/platform_profiler.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { namespace profiling { -std::unique_ptr CreatePlatformProfiler(); +std::unique_ptr MaybeCreatePlatformProfiler(); } // namespace profiling } // namespace tflite diff --git a/tensorflow/lite/profiling/profile_buffer_test.cc b/tensorflow/lite/profiling/profile_buffer_test.cc index ab98cbb0d13..457b6ff2aba 100644 --- a/tensorflow/lite/profiling/profile_buffer_test.cc +++ b/tensorflow/lite/profiling/profile_buffer_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { @@ -121,9 +120,3 @@ TEST(ProfileBufferTest, Enable) { } // namespace } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc index 98d26196b75..fd81c00e603 100644 --- a/tensorflow/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/lite/profiling/profile_summarizer_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/profiling/buffered_profiler.h" -#include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/version.h" namespace tflite { @@ -224,9 +223,3 @@ TEST_F(ProfileSummarizerIfOpTest, TestIfFalse) { } // namespace } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc index 78d46aae1ea..0de0e733842 100644 --- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc +++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { @@ -156,9 +155,3 @@ TEST(SummaryWriterTest, DelegationShortSummary) { } // namespace } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/profiling/profiler_test.cc b/tensorflow/lite/profiling/profiler_test.cc index 1d8455e3647..c59dca9738e 100644 --- a/tensorflow/lite/profiling/profiler_test.cc +++ b/tensorflow/lite/profiling/profiler_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { @@ -136,9 +135,3 @@ TEST(ProfilingTest, NoopProfiler) { } // namespace } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/profiling/test_main.cc b/tensorflow/lite/profiling/test_main.cc new file mode 100644 index 00000000000..df6b8cb0477 --- /dev/null +++ b/tensorflow/lite/profiling/test_main.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/lite/testing/util.h" + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/profiling/time_test.cc b/tensorflow/lite/profiling/time_test.cc index 6f08479adeb..8a85de9fe51 100644 --- a/tensorflow/lite/profiling/time_test.cc +++ b/tensorflow/lite/profiling/time_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/profiling/time.h" #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { @@ -48,9 +47,3 @@ TEST(TimeTest, SleepForMicros) { } // namespace time } // namespace profiling } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index fa046706e52..170db3f6bce 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -85,8 +85,6 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) - # We don't support integer types as we don't have statistical information - # to quantize (only supported for post training integer quantization). with self.assertRaises(ValueError) as error: converter.inference_input_type = inference_input_output_type converter.inference_output_type = inference_input_output_type @@ -212,8 +210,6 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): # Convert quantized model. quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) quantized_converter.optimizations = [lite.Optimize.DEFAULT] - # We don't support integer types as we don't have statistical information - # to quantize (only supported for post training integer quantization). with self.assertRaises(ValueError) as error: quantized_converter.inference_input_type = inference_input_output_type quantized_converter.inference_output_type = inference_input_output_type @@ -223,11 +219,20 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): 'must be tf.float32.', str(error.exception)) @parameterized.named_parameters( - ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT), - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8)) - def testPostTrainingIntegerAllowFloatQuantization( - self, inference_input_output_type): + ('_Default', False, False, lite.constants.FLOAT), + ('_INT8InputOutput', False, False, lite.constants.INT8), + ('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8), + ('_INT16Quantize', False, True, lite.constants.FLOAT), + ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16), + ('_IntOnly', True, False, lite.constants.FLOAT), + ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8), + ('_IntOnly_UINT8InputOutput', True, False, + lite.constants.QUANTIZED_UINT8), + ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT), + ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, + lite.constants.INT16)) + def testIntegerQuantization(self, is_int_only, is_int16_quantize, + inference_input_output_type): func, calibration_gen = self._getIntegerQuantizeModel() # Convert float model. @@ -239,111 +244,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) quantized_converter.optimizations = [lite.Optimize.DEFAULT] quantized_converter.representative_dataset = calibration_gen - quantized_converter.inference_input_type = inference_input_output_type - quantized_converter.inference_output_type = inference_input_output_type - quantized_tflite_model = quantized_converter.convert() - self.assertIsNotNone(quantized_tflite_model) - - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - self.assertLen(input_details, 1) - self.assertEqual(inference_input_output_type.as_numpy_dtype, - input_details[0]['dtype']) - output_details = interpreter.get_output_details() - self.assertLen(output_details, 1) - self.assertEqual(inference_input_output_type.as_numpy_dtype, - output_details[0]['dtype']) - - # Ensure that the quantized tflite model is smaller. - self.assertLess(len(quantized_tflite_model), len(tflite_model)) - - def testPostTrainingIntegerAllowFloatQuantizationINT16InputOutput(self): - func, calibration_gen = self._getIntegerQuantizeModel() - - # Convert float model. - converter = lite.TFLiteConverterV2.from_concrete_functions([func]) - tflite_model = converter.convert() - self.assertTrue(tflite_model) - - # Post-training quantization 16x8 with float fallback allowed. - quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) - quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.representative_dataset = calibration_gen - quantized_converter.target_spec.supported_ops = [ - lite.OpsSet.\ - EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, - lite.OpsSet.TFLITE_BUILTINS - ] - inference_input_output_type = lite.constants.INT16 - quantized_converter.inference_input_type = inference_input_output_type - quantized_converter.inference_output_type = inference_input_output_type - quantized_tflite_model = quantized_converter.convert() - self.assertIsNotNone(quantized_tflite_model) - - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - self.assertLen(input_details, 1) - self.assertEqual(inference_input_output_type.as_numpy_dtype, - input_details[0]['dtype']) - output_details = interpreter.get_output_details() - self.assertLen(output_details, 1) - self.assertEqual(inference_input_output_type.as_numpy_dtype, - output_details[0]['dtype']) - - # Ensure that the quantized tflite model is smaller. - self.assertLess(len(quantized_tflite_model), len(tflite_model)) - - def testPostTrainingIntegerQuant16x8MismatchInferenceParams(self): - # In this test we check that when we do 16x8 post-training - # quantization and set inference_input(output)_type to - # constants.INT8, we have an error. - func, calibration_gen = self._getIntegerQuantizeModel() - - # Convert quantized model. - quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) - quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.representative_dataset = calibration_gen - quantized_converter.target_spec.supported_ops = [ - lite.OpsSet.\ - EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 - ] - - with self.assertRaises(ValueError) as error: - quantized_converter.inference_input_type = lite.constants.INT8 - quantized_converter.inference_output_type = lite.constants.INT8 - quantized_converter.convert() - self.assertEqual( - "The inference_input_type and inference_output_type " - "must be in ['tf.float32', 'tf.int16'].", str(error.exception)) - - @parameterized.named_parameters( - ('_DefaultFLOAT32InputOutput_UseTargetTypesFlag', lite.constants.FLOAT, - False, False), - ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT, True, False), - ('_INT8InputOutput', lite.constants.INT8, True, False), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8, True, False), - ('_INT16InputOutput', lite.constants.INT16, True, True)) - @test_util.run_v2_only - def testPostTrainingIntegerNoFloatQuantization(self, - inference_input_output_type, - use_target_ops_flag, - quantization_16x8): - func, calibration_gen = self._getIntegerQuantizeModel() - - # Convert float model. - converter = lite.TFLiteConverterV2.from_concrete_functions([func]) - tflite_model = converter.convert() - self.assertTrue(tflite_model) - - # Convert model by specifying target spec (instead of optimizations), since - # when targeting an integer only backend, quantization is mandatory. - quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) - quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.representative_dataset = calibration_gen - if use_target_ops_flag: - if quantization_16x8: + if is_int_only: + if is_int16_quantize: quantized_converter.target_spec.supported_ops = [ lite.OpsSet.\ EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 @@ -353,7 +255,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): lite.OpsSet.TFLITE_BUILTINS_INT8 ] else: - quantized_converter.target_spec.supported_types = [lite.constants.INT8] + if is_int16_quantize: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.\ + EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, + lite.OpsSet.TFLITE_BUILTINS + ] quantized_converter.inference_input_type = inference_input_output_type quantized_converter.inference_output_type = inference_input_output_type quantized_tflite_model = quantized_converter.convert() @@ -373,6 +280,30 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): # Ensure that the quantized tflite model is smaller. self.assertLess(len(quantized_tflite_model), len(tflite_model)) + @parameterized.named_parameters( + ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8)) + def testInvalidIntegerQuantization(self, is_int16_quantize, + inference_input_output_type): + func, calibration_gen = self._getIntegerQuantizeModel() + + # Convert quantized model. + quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) + quantized_converter.optimizations = [lite.Optimize.DEFAULT] + quantized_converter.representative_dataset = calibration_gen + if is_int16_quantize: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.\ + EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, + lite.OpsSet.TFLITE_BUILTINS + ] + with self.assertRaises(ValueError) as error: + quantized_converter.inference_input_type = lite.constants.INT8 + quantized_converter.inference_output_type = lite.constants.INT8 + quantized_converter.convert() + self.assertEqual( + "The inference_input_type and inference_output_type " + "must be in ['tf.float32', 'tf.int16'].", str(error.exception)) + def testCalibrateAndQuantizeBuiltinInt16(self): func, calibration_gen = self._getIntegerQuantizeModel() diff --git a/tensorflow/lite/testing/op_tests/leaky_relu.py b/tensorflow/lite/testing/op_tests/leaky_relu.py index e37df7722f5..0d2ec384917 100644 --- a/tensorflow/lite/testing/op_tests/leaky_relu.py +++ b/tensorflow/lite/testing/op_tests/leaky_relu.py @@ -28,12 +28,13 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function def make_leaky_relu_tests(options): """Make a set of tests to do LeakyRelu.""" - test_parameters = [ - { - "input_shape": [[], [1], [5], [1, 10, 10, 3], [3, 3, 3, 3]], - "alpha": [0.1, 1.0, 2.0, -0.1, -1.0, -2.0], - }, - ] + test_parameters = [{ + "input_shape": [[], [1], [5], [1, 10, 10, 3], [3, 3, 3, 3]], + "alpha": [0.1, 1.0, 2.0, -0.1, -1.0, -2.0], + "fully_quantize": [False, True], + "input_range": [(-3, 10)], + "quant_16x8": [False, True], + }] def build_graph(parameters): """Build the graph for the test case.""" diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 774e7ed7088..eb3f37aef58 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -158,7 +158,6 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:cpu_backend_context", - "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:profile_summary_formatter", "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/tools:logging", diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index 453ea5b986a..df432daa2e2 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -36,11 +36,6 @@ and the following optional parameters: mean use no delay. * `enable_op_profiling`: `bool` (default=false) \ Whether to enable per-operator profiling measurement. -* `enable_platform_tracing`: `bool` (default=false) \ - Whether to enable platform-wide tracing. Needs to be combined with - 'enable_op_profiling'. Note, the platform-wide tracing might not work if the - tool runs as a commandline native binary. For example, on Android, the - ATrace-based tracing only works when the tool is launched as an APK. * `profiling_output_csv_file`: `str` (default="") \ File path to export profile data to as CSV. The results are printed to `stdout` if option is not set. Requires `enable_op_profiling` to be `true` diff --git a/tensorflow/lite/tools/benchmark/android/README.md b/tensorflow/lite/tools/benchmark/android/README.md index 3475d47632a..d41090d9515 100644 --- a/tensorflow/lite/tools/benchmark/android/README.md +++ b/tensorflow/lite/tools/benchmark/android/README.md @@ -96,7 +96,13 @@ page for more detailed information. (0)-(3) Follow the steps (0)-(3) of [build/install/run](#to-buildinstallrun) section. -(4) Set up Quick Settings tile for System Tracing app on your device. Follow the +(4) Enable platform tracing. + +``` +adb shell setprop debug.tflite.trace 1 +``` + +(5) Set up Quick Settings tile for System Tracing app on your device. Follow the [instruction](https://developer.android.com/topic/performance/tracing/on-device#set-up-tile). The System Tracing tile will be added to the Quick Settings panel. @@ -105,20 +111,20 @@ Refer to the [guide](https://developer.android.com/topic/performance/tracing/on-device#app-menu) for more information. -(5) Tap the System Tracing tile, which has the label "Record trace". The tile +(6) Tap the System Tracing tile, which has the label "Record trace". The tile becomes enabled, and a persistent notification appears to notify you that the system is now recording a trace. -(6) Run the benchmark with platform tracing enabled. +(7) Run the benchmark with platform tracing enabled. ``` adb shell am start -S \ -n org.tensorflow.lite.benchmark/.BenchmarkModelActivity \ --es args '"--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ - --num_threads=4 --enable_op_profiling=true --enable_platform_tracing=true"' + --num_threads=4"' ``` -(7) Wait until the benchmark finishes. It can be checked from Android log +(8) Wait until the benchmark finishes. It can be checked from Android log messages, e.g., ``` @@ -127,14 +133,14 @@ adb logcat | grep "Average inference" ... tflite : Average inference timings in us: Warmup: 91471, Init: 4108, Inference: 80660.1 ``` -(8) Stop tracing by tapping either the System Tracing tile in the Quick Settings +(9) Stop tracing by tapping either the System Tracing tile in the Quick Settings panel or on the System Tracing notification. The system displays a new notification that contains the message "Saving trace". When saving is complete, the system dismisses the notification and displays a third notification "Trace saved", confirming that your trace has been saved and that you're ready to share the system trace. -(9) +(10) [Share](https://developer.android.com/topic/performance/tracing/on-device#share-trace) a trace file, [convert](https://developer.android.com/topic/performance/tracing/on-device#converting_between_trace_formats) @@ -143,3 +149,9 @@ between tracing formats and an HTML report. Note that, the captured tracing file format is either in Perfetto format or in Systrace format depending on the Android version of your device. Select the appropriate method to handle the generated file. + +(11) Disable platform tracing. + +``` +adb shell setprop debug.tflite.trace 0 +``` diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index ef9742eaac7..511244cee88 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" -#include "tensorflow/lite/profiling/platform_profiler.h" #include "tensorflow/lite/profiling/profile_summary_formatter.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/benchmark/benchmark_utils.h" @@ -61,20 +60,6 @@ constexpr int kOpProfilingEnabledDefault = true; constexpr int kOpProfilingEnabledDefault = false; #endif -// Dumps platform-wide tracing files via a platform-based profiler that's built -// upon platform tracing tools, like ATrace on Android etc. -class PlatformProfilingListener : public BenchmarkListener { - public: - explicit PlatformProfilingListener(Interpreter* interpreter) { - TFLITE_TOOLS_CHECK(interpreter); - platform_profiler_ = profiling::CreatePlatformProfiler(); - interpreter->SetProfiler(platform_profiler_.get()); - } - - private: - std::unique_ptr platform_profiler_; -}; - // Dumps ruy profiling events if the ruy profiler is enabled. class RuyProfileListener : public BenchmarkListener { public: @@ -269,8 +254,6 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create(1024)); default_params.AddParam("profiling_output_csv_file", BenchmarkParam::Create("")); - default_params.AddParam("enable_platform_tracing", - BenchmarkParam::Create(false)); for (const auto& delegate_provider : tools::GetRegisteredDelegateProviders()) { @@ -331,10 +314,7 @@ std::vector BenchmarkTfLiteModel::GetFlags() { CreateFlag( "profiling_output_csv_file", ¶ms_, "File path to export profile data as CSV, if not set " - "prints to stdout."), - CreateFlag("enable_platform_tracing", ¶ms_, - "enable platform-wide tracing, only meaningful when " - "--enable_op_profiling is set to true.")}; + "prints to stdout.")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); @@ -369,8 +349,6 @@ void BenchmarkTfLiteModel::LogParams() { "Max profiling buffer entries", verbose); LOG_BENCHMARK_PARAM(std::string, "profiling_output_csv_file", "CSV File to export profiling data to", verbose); - LOG_BENCHMARK_PARAM(bool, "enable_platform_tracing", - "Enable platform-wide tracing", verbose); for (const auto& delegate_provider : tools::GetRegisteredDelegateProviders()) { @@ -746,11 +724,6 @@ std::unique_ptr BenchmarkTfLiteModel::MayCreateProfilingListener() const { if (!params_.Get("enable_op_profiling")) return nullptr; - if (params_.Get("enable_platform_tracing")) { - return std::unique_ptr( - new PlatformProfilingListener(interpreter_.get())); - } - return std::unique_ptr(new ProfilingListener( interpreter_.get(), params_.Get("max_profiling_buffer_entries"), params_.Get("profiling_output_csv_file"), diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index d320a90d005..31405dfb998 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -226,6 +226,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a); } \ } while (0) +#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon) \ + do { \ + auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a)); \ + if (delta > epsilon) { \ + TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)", \ + __FILE__, __LINE__, #a, #b, static_cast(a), \ + static_cast(b)); \ + return kTfLiteError; \ + } \ + } while (0) + #define TF_LITE_ENSURE_OK(context, status) \ do { \ const TfLiteStatus s = (status); \ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c8be168d814..f39797f8158 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -678,6 +678,7 @@ tf_python_pybind_extension( "client/tf_session_helper.h", "lib/core/numpy.h", "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "//tensorflow/c:headers", "//tensorflow/c/eager:headers", "//tensorflow/c/eager:pywrap_required_hdrs", @@ -880,6 +881,7 @@ tf_python_pybind_extension( hdrs = [ "lib/core/ndarray_tensor.h", "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", ":py_exception_registry_hdr", "//tensorflow/c:checkpoint_reader_hdrs", "//tensorflow/c:headers", @@ -940,12 +942,17 @@ tf_python_pybind_extension( ], ) +# TODO(edloper): Remove unused dependency on safe_ptr. (Blocker: there are +# targets that depend are relying on cpp_python_util to pull in safe_ptr's +# third_party/tensorflow/c:c_api_no_xla dependency, which registers +# ops/gradients, rather than depending on it themselves.) cc_library( name = "cpp_python_util", srcs = ["util/util.cc"], hdrs = ["util/util.h"], deps = [ ":safe_ptr", + ":safe_pyobject_ptr", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/python_runtime:headers", @@ -1000,6 +1007,15 @@ tf_python_pybind_extension( ], ) +cc_library( + name = "safe_pyobject_ptr", + srcs = ["lib/core/safe_pyobject_ptr.cc"], + hdrs = ["lib/core/safe_pyobject_ptr.h"], + deps = [ + "//third_party/python_runtime:headers", + ], +) + cc_library( name = "safe_ptr", srcs = [ @@ -1008,6 +1024,7 @@ cc_library( ], hdrs = ["lib/core/safe_ptr.h"], deps = [ + ":safe_pyobject_ptr", "//tensorflow/c:c_api_no_xla", "//third_party/python_runtime:headers", ], @@ -1021,6 +1038,7 @@ cc_library( "lib/core/ndarray_tensor_bridge.h", "lib/core/numpy.h", "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "//tensorflow/c:headers", "//tensorflow/c/eager:headers", ], @@ -1627,13 +1645,47 @@ py_library( ], ) +cc_library( + name = "py_context_manager", + srcs = ["framework/py_context_manager.cc"], + hdrs = ["framework/py_context_manager.h"], + deps = [ + ":safe_pyobject_ptr", + "//tensorflow/core:lib", # for core/platform/logging.h + "//third_party/python_runtime:headers", + ], +) + +# Pybind extension used by py_context_manager_test. +tf_python_pybind_extension( + name = "_py_context_manager", + srcs = ["framework/py_context_manager_pybind.cc"], + module_name = "_py_context_manager", + deps = [ + ":py_context_manager", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + +tf_py_test( + name = "py_context_manager_test", + srcs = ["framework/py_context_manager_test.py"], + python_version = "PY3", + tags = ["no_pip"], + tfrt_enabled = True, + deps = [ + ":_py_context_manager", + ], +) + cc_library( name = "op_def_util_cc", srcs = ["framework/op_def_util.cc"], hdrs = ["framework/op_def_util.h"], deps = [ ":cpp_python_util", - ":safe_ptr", + ":safe_pyobject_ptr", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", ], @@ -1644,6 +1696,8 @@ cc_library( # depending on that target adds dependencies that register objects; and since the # extension is built as a shared object in some kokoro tests, this causes those objects # to get registered multiple times (which fails). +# TODO(edloper): Simplify this, once cpp_python_util is changed to not depend on +# safe_ptr (which transitively depends on third_party/tensorflow/c:c_api_no_xla). tf_python_pybind_extension( name = "_op_def_util", srcs = [ @@ -1653,6 +1707,7 @@ tf_python_pybind_extension( hdrs = [ "framework/op_def_util.h", "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "util/util.h", "//tensorflow/c:headers", "//tensorflow/c/eager:headers", @@ -3421,6 +3476,25 @@ tf_py_test( ], ) +tf_py_test( + name = "collective_ops_multi_worker_test", + size = "medium", + srcs = ["ops/collective_ops_multi_worker_test.py"], + python_version = "PY3", + tags = ["no_rocm"], + tfrt_enabled = False, + deps = [ + ":collective_ops", + ":constant_op", + ":errors", + "//tensorflow/python/distribute:multi_process_runner", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + tf_py_test( name = "collective_ops_xla_test", size = "small", @@ -8311,6 +8385,7 @@ tf_python_pybind_extension( srcs = ["mlir_wrapper.cc"], hdrs = [ "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "//tensorflow/c:headers", "//tensorflow/c/eager:headers", "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs", @@ -8342,6 +8417,7 @@ tf_python_pybind_extension( srcs = ["tfe_wrapper.cc"], hdrs = [ "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "util/util.h", ":py_exception_registry_hdr", "//tensorflow/c:headers", @@ -8413,6 +8489,7 @@ tf_python_pybind_extension( name = "_pywrap_parallel_device", srcs = [ "lib/core/safe_ptr.h", + "lib/core/safe_pyobject_ptr.h", "//tensorflow/c:headers", "//tensorflow/c/eager:headers", "//tensorflow/c/eager/parallel_device:headers", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index b5acf23ba79..ce26271f236 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -36,6 +36,7 @@ import traceback # go/tf-wildcard-import # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top +from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python.eager import context diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index b5729c2408e..4467677f0d1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 8, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 8, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 387d379e601..6570e511cbe 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -145,9 +145,8 @@ cuda_py_tests( ], python_version = "PY3", tags = [ - "no_cuda11", # TODO(b/165611343): Need to address the failures. "no_cuda_on_cpu_tap", - "no_oss", + "no_oss", # TODO(b/165611343): Need to address the failures for CUDA 11 in OSS build. "no_rocm", "no_windows", "nomac", @@ -170,6 +169,7 @@ cuda_py_test( ], python_version = "PY3", tags = [ + "no_cuda11", # TODO(b/166308253): enable the test for CUDA 11. "no_cuda_on_cpu_tap", "no_oss", # TODO(b/125290478): allow running in at least some OSS configurations. "no_pip", diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index 195382cd8ed..9d2d3abd4fb 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -70,12 +70,6 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): ] } - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -136,12 +130,6 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): return conversion_params._replace( rewriter_config_template=rewrite_config_with_trt) - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase): diff --git a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py index 26e911e3b0b..3f2a5469ae6 100644 --- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py +++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py @@ -89,9 +89,6 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase): } def ShouldRunTest(self, run_params): - # TODO(b/162447069): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, 'Skip test due to b/162447069') # There is no CombinedNonMaxSuppression op for GPU at the moment, so # calibration will fail. # TODO(laigd): fix this. diff --git a/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py index 9e71b9e3f75..ccbaf9e52fa 100644 --- a/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py @@ -60,12 +60,6 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): """The relative tolerance to compare floating point results.""" return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02 - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, 'Skip test due to b/162448349') - return super().ShouldRunTest(run_params) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/conv2d_test.py b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py index 400c17b343e..df1adce2178 100644 --- a/tensorflow/python/compiler/tensorrt/test/conv2d_test.py +++ b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py @@ -114,12 +114,6 @@ class Conv2DNCHWTest(trt_test.TfTrtIntegrationTestBase): return 4e-02 return super(Conv2DNCHWTest, self).ExpectedRelativeTolerance(run_params) - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase): """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion.""" @@ -143,12 +137,6 @@ class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0"] - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase): """Testing conversion of strided Conv2D (data_format=NCHW).""" @@ -180,12 +168,6 @@ class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0"] - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - class Conv2DTranposeTest(trt_test.TfTrtIntegrationTestBase): """Testing conversion of conv2d_transpose (AKA Conv2DBackpropInput)""" diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py index f02ad08777e..95dbe727ac3 100644 --- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py @@ -98,9 +98,6 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase): return ["TRTEngineOp_0"] def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") return (run_params.dynamic_engine and not trt_test.IsQuantizationMode( run_params.precision_mode)), "test dynamic engine and non-INT8" diff --git a/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py index c1f0a007bf8..056edc3e4d4 100644 --- a/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py @@ -67,12 +67,6 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): """The relative tolerance to compare floating point results.""" return 0.1 - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py index 687a12486b7..b57bee6c5d7 100644 --- a/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -72,12 +72,6 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0", "TRTEngineOp_1"] - def ShouldRunTest(self, run_params): - # TODO(b/162447069): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162447069") - return super().ShouldRunTest(run_params) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py index 39fee5cba5d..f377fe8dceb 100644 --- a/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py @@ -61,12 +61,6 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): "TRTEngineOp_1": ["weights", "conv"] } - def ShouldRunTest(self, run_params): - # TODO(b/162447069): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162447069") - return super().ShouldRunTest(run_params) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py index d859407f1f7..000b231a61a 100644 --- a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py @@ -261,10 +261,6 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): if not is_tensorrt_enabled(): return - # TODO(b/162447069): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return - model_dir = test.test_src_dir_path( 'python/compiler/tensorrt/test/testdata/mnist') diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py index 43034e8b31e..8fd9606812d 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py @@ -76,12 +76,6 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): super(trt_test.TfTrtIntegrationTestBase, self).setUp() os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py index 7b1f7e062d7..9d81cd6dcc3 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py @@ -67,12 +67,6 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): super(trt_test.TfTrtIntegrationTestBase, self).setUp() os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" - def ShouldRunTest(self, run_params): - # TODO(b/162448349): Enable the test for TRT 7.1.3. - if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3): - return (False, "Skip test due to b/162448349") - return super().ShouldRunTest(run_params) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py index a091bdca8b9..e7c84ee5d60 100644 --- a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple from absl.testing import parameterized from tensorflow.python.data.experimental.ops import compression_ops @@ -25,14 +26,24 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import combinations from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test def _test_objects(): + + Item = namedtuple("Item", "id name") + return [ combinations.NamedObject("int", 1), combinations.NamedObject("string", "dog"), combinations.NamedObject("tuple", (1, 1)), + combinations.NamedObject("nested_tuple", ((1, 1), (2, 2))), + combinations.NamedObject("named_tuple", Item(id=1, name="item1")), + combinations.NamedObject("unicode", "アヒル"), + combinations.NamedObject( + "nested_named_tuple", + (Item(id=1, name="item1"), Item(id=2, name="item2"))), combinations.NamedObject("int_string_tuple", (1, "dog")), combinations.NamedObject( "sparse", @@ -50,11 +61,32 @@ def _test_objects(): ] +def _test_v2_eager_only_objects(): + return [ + combinations.NamedObject( + "ragged", + ragged_factory_ops.constant([[0, 1, 2, 3], [4, 5], [6, 7, 8], [9]])), + combinations.NamedObject( + "sparse_ragged_structured", { + "sparse": + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], + values=[1, 2], + dense_shape=[3, 4]), + "ragged": + ragged_factory_ops.constant([[0, 1, 2, 3], [9]]) + }) + ] + + class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), - combinations.combine(element=_test_objects()))) + combinations.combine(element=_test_objects())) + + combinations.times( + test_base.v2_eager_only_combinations(), + combinations.combine(element=_test_v2_eager_only_objects()))) def testCompression(self, element): element = element._obj @@ -65,7 +97,10 @@ class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), - combinations.combine(element=_test_objects()))) + combinations.combine(element=_test_objects())) + + combinations.times( + test_base.v2_eager_only_combinations(), + combinations.combine(element=_test_v2_eager_only_objects()))) def testDatasetCompression(self, element): element = element._obj diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py index 05d0968ae5a..cab41268583 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py @@ -104,6 +104,28 @@ def _captured_refvar_test_combinations(): return functools.reduce(reduce_fn, cases, []) +def _disable_intra_op_parallelism_test_combinations(): + + def make_tensor_dataset(): + return dataset_ops.Dataset.from_tensors(42) + + def make_map_dataset(): + return dataset_ops.Dataset.from_tensors(42).map(lambda x: x + 1) + + cases = [ + ("FromTensors", make_tensor_dataset, [42]), + ("Map", make_map_dataset, [43]), + ] + + def reduce_fn(x, y): + name, dataset_fn, expected_output = y + return x + combinations.combine( + dataset_fn=combinations.NamedObject(name, dataset_fn), + expected_output=[expected_output]) + + return functools.reduce(reduce_fn, cases, []) + + class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -186,15 +208,18 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) - @combinations.generate(test_base.default_test_combinations()) - def testOptimizationDisableIntraOpParallelism(self): + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _disable_intra_op_parallelism_test_combinations())) + def testOptimizationDisableIntraOpParallelism(self, dataset_fn, + expected_output): os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "disable_intra_op_parallelism" os.environ["TF_JOB_NAME"] = "test_job" - dataset = dataset_ops.Dataset.range(10).map(lambda x: x+1) + dataset = dataset_fn() dataset = dataset.apply(testing.assert_next(["MaxIntraOpParallelism"])) - self.assertDatasetProduces(dataset, expected_output=list(range(1, 11))) + self.assertDatasetProduces(dataset, expected_output=expected_output) del os.environ["TF_DATA_EXPERIMENT_OPT_IN"] del os.environ["TF_JOB_NAME"] diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py index cbff39b90e5..e9a4d52599a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py @@ -86,18 +86,16 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(dataset, range(1, 11)) @combinations.generate(test_base.default_test_combinations()) - def testErrorWithoutPrefetch(self): - """The rewrite fails if there is no prefetch() in the pipeline.""" + def testNoErrorWithoutPrefetch(self): + """The rewrite should not fail if there is no prefetch() in the pipeline.""" dataset = dataset_ops.Dataset.range(10) options = dataset_ops.Options() options.experimental_slack = True dataset = dataset.with_options(options) - with self.assertRaises(errors.InvalidArgumentError): - get_next = self.getNext(dataset) - self.evaluate(get_next()) + self.assertDatasetProduces(dataset, range(10)) @combinations.generate(test_base.default_test_combinations()) - def testErrorWithInvalidDataset(self): + def testNoErrorWithInvalidDataset(self): """With a nested dataset op after prefetch, the rewrite should fail.""" dataset = dataset_ops.Dataset.range(10) dataset = dataset.prefetch(1) @@ -105,9 +103,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase): options = dataset_ops.Options() options.experimental_slack = True dataset = dataset.with_options(options) - with self.assertRaises(errors.InvalidArgumentError): - get_next = self.getNext(dataset) - self.evaluate(get_next()) + self.assertDatasetProduces(dataset, range(10)) if __name__ == "__main__": diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index 568c01646de..5105f30fd07 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset): return self._element_spec -def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name +def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name return dataset_ops.DatasetV1Adapter( - _AutoShardDataset(input_dataset, num_workers, index, num_replicas)) + _AutoShardDataset(input_dataset, num_workers, index)) class _RebatchDataset(dataset_ops.UnaryDataset): diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc index b268ba2403a..8ce904eecba 100644 --- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc +++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc @@ -63,7 +63,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { } std::unique_ptr server; tensorflow::Status status = - tensorflow::data::NewDispatchServer(config, &server); + tensorflow::data::NewDispatchServer(config, server); tensorflow::MaybeRaiseFromStatus(status); return server; }, @@ -80,7 +80,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { } std::unique_ptr server; tensorflow::Status status = - tensorflow::data::NewWorkerServer(config, &server); + tensorflow::data::NewWorkerServer(config, server); tensorflow::MaybeRaiseFromStatus(status); return server; }, diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index 310a60b8114..6bc64ddae15 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -600,10 +600,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): _make_distributed_dataset(dataset, dispatcher) return dataset - with self.assertRaisesRegex( - errors.InvalidArgumentError, r"The `.distribute\(...\)` dataset " - "transformation is not supported within tf.data functions"): - ds = ds.interleave(interleave_fn, cycle_length=2) + ds = ds.interleave(interleave_fn, cycle_length=2) + self.assertDatasetProduces(ds, [0, 0, 1, 1]) @combinations.generate(test_base.eager_only_combinations()) def testDistributeNonStringAddresses(self): diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index 31595363bd5..5da633a9ee2 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -52,10 +52,15 @@ def graph_only_combinations(): def v2_only_combinations(): - """Returns the default test combinations for v1 only tf.data tests.""" + """Returns the default test combinations for v2 only tf.data tests.""" return combinations.combine(tf_api_version=2, mode=["eager", "graph"]) +def v2_eager_only_combinations(): + """Returns the default test combinations for v2 eager only tf.data tests.""" + return combinations.combine(tf_api_version=2, mode="eager") + + class DatasetTestBase(test.TestCase): """Base class for dataset tests.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index f6f2da0939e..479c8d337a0 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -36,7 +36,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import deprecation @@ -656,11 +655,7 @@ class OwnedIterator(IteratorBase): in eager mode and inside of tf.functions. """ - def __init__(self, - dataset=None, - components=None, - element_spec=None, - job_token=None): + def __init__(self, dataset=None, components=None, element_spec=None): """Creates a new iterator from the given dataset. If `dataset` is not specified, the iterator will be created from the given @@ -673,20 +668,17 @@ class OwnedIterator(IteratorBase): components: Tensor components to construct the iterator from. element_spec: A nested structure of `TypeSpec` objects that represents the type specification of elements of the iterator. - job_token: A token to use for reading from a tf.data service job. Data - will be partitioned among all iterators using the same token. If `None`, - the iterator will not read from the tf.data service. Raises: ValueError: If `dataset` is not provided and either `components` or `element_spec` is not provided. Or `dataset` is provided and either `components` and `element_spec` is provided. """ + super(OwnedIterator, self).__init__() error_message = ("Either `dataset` or both `components` and " "`element_spec` need to be provided.") self._device = context.context().device_name - self._job_token = job_token if dataset is None: if (components is None or element_spec is None): @@ -729,11 +721,7 @@ class OwnedIterator(IteratorBase): gen_dataset_ops.anonymous_iterator_v2( output_types=self._flat_output_types, output_shapes=self._flat_output_shapes)) - if self._job_token is None: - gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource) - else: - gen_experimental_dataset_ops.make_data_service_iterator( - ds_variant, self._job_token, self._iterator_resource) + gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource) # Delete the resource when this object is deleted self._resource_deleter = IteratorResourceDeleter( handle=self._iterator_resource, diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 19c7e98b5a5..2b03b9d52a8 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -13,6 +13,7 @@ exports_files(["LICENSE"]) py_library( name = "distribute_test_lib_pip", deps = [ + ":all_reduce", ":combinations", ":multi_worker_test_base", ":single_loss_example", @@ -89,7 +90,6 @@ py_library( srcs = ["cross_device_utils.py"], srcs_version = "PY2AND3", deps = [ - ":all_reduce", ":values", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", @@ -141,6 +141,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":all_reduce", ":cross_device_ops", ":distribute_lib", ":mirrored_strategy", diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py index e431c19459e..7f3559caecd 100644 --- a/tensorflow/python/distribute/client/client.py +++ b/tensorflow/python/distribute/client/client.py @@ -32,6 +32,8 @@ import threading import weakref from absl import logging from six.moves import queue + +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute.client import metric_utils @@ -1145,15 +1147,12 @@ class _PerWorkerDistributedDataset(object): per_worker_iterator = self._client._create_per_worker_resources( _create_per_worker_iterator) - # Create an iterator, so the consumer function of this iterator can start - # tracing using this iterator without needing to wait for the completion of - # the iterater creation. Note: the iterator shouldn't use memory until it is - # consumed. - # TODO(b/154675763): get rid of this workaround once we can make input_fn a - # tf.function. - iterator = _create_per_worker_iterator() + # Setting type_spec of each RemoteValue so that functions taking these + # RemoteValues as inputs can be traced. for iterator_remote_value in per_worker_iterator._values: - iterator_remote_value._set_type_spec(iterator._type_spec) + iterator_remote_value._set_type_spec( + iterator_ops.IteratorSpec( + self._dataset_fn.structured_outputs.element_spec)) return _PerWorkerDistributedIterator(per_worker_iterator._values) @property diff --git a/tensorflow/python/distribute/client/utils.py b/tensorflow/python/distribute/client/utils.py index 6c595579863..51d82630d6e 100644 --- a/tensorflow/python/distribute/client/utils.py +++ b/tensorflow/python/distribute/client/utils.py @@ -28,8 +28,11 @@ def start_server(cluster_resolver, protocol): """Start a server and block the process from exiting.""" # This function is for multi-processing test or users who would like to have # every job run the same binary for simplicity. - assert (cluster_resolver.task_type == 'worker' or - cluster_resolver.task_type == 'ps') + if not (cluster_resolver.task_type == 'worker' or + cluster_resolver.task_type == 'ps'): + raise ValueError('Unexpected task_type to start a server: {}'.format( + cluster_resolver.task_type)) + server = server_lib.Server( cluster_resolver.cluster_spec().as_cluster_def(), job_name=cluster_resolver.task_type, diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index d0eed9b8cff..49b6a93678c 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -25,6 +25,7 @@ import weakref from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util @@ -39,7 +40,6 @@ from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -188,6 +188,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): _check_health_interval = 30 # Timeout in seconds for the first check health. The first check health needs # to wait for cluster, which may make a longer time. + # + # TODO(b/151232436): now the inital barrier may hang in a rare case, so we + # need a finite timeout. _check_health_initial_timeout = 1200 def __init__(self, @@ -473,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers_with_options(options), self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync, + split_batch_by=self._num_replicas_in_sync, input_context=input_context) def _experimental_distribute_datasets_from_function(self, dataset_fn, @@ -503,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers, self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync, + split_batch_by=self._num_replicas_in_sync, input_context=input_context) def _make_input_fn_iterator( @@ -629,55 +632,20 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): destinations=destinations, experimental_hints=experimental_hints) - def _check_health(self, device, group_key, instance_key): - first = True - # We need to use a large enough value so that the all-reduce forms a - # complete RING. In RING implementation, when value is too small, the - # all-reduce may degrade into broadcasts. This means that some worker - # failure may not be detected. - value = array_ops.ones((32, 32), dtype=dtypes.float32) + def _check_health(self): while True: if self._check_health_thread_should_stop.is_set(): return - timeout = None - if first: - # For the first check health we set timeout since it may need to do - # group resolution, which may hang if the cluster is never healthy. - timeout = self._check_health_initial_timeout - first = False try: - # We use an dummy all-reduce as a way to check the health of a cluster. - # For RING it should be able to detect failed workers in the cluster if - # the values are large enough. - # - # We're not using CrossDeviceOps because we need to run it with - # pre-allocated group and instance keys. - # - # TODO(b/151232436): Replace the reduce with a check health op once we - # add that. - with ops.device(device): - collective_ops.all_reduce( - value, - group_size=self._num_workers, - group_key=group_key, - instance_key=instance_key, - merge_op="Add", - final_op="Id", - subdiv_offsets=[0], - communication_hint="ring", - timeout=timeout) - if context.is_async(): - context.async_wait() - except (errors.UnavailableError, errors.DeadlineExceededError, - errors.FailedPreconditionError, errors.CancelledError) as e: + for job in self._cluster_spec.jobs: + for task_id in range(self._cluster_spec.num_tasks(job)): + context.context().check_collective_ops_peer_health( + "/job:{}/replica:0/task:{}".format(job, task_id)) + except (errors.UnavailableError, errors.FailedPreconditionError) as e: # TODO(b/151232436): Always raise UnavailableError when a peer fails. # Now there could be many kinds of errors: # - Unavailable: when the peer is not reachable, e.g. it's down. # - FailedPrecondition: when the peer has restarted. - # - DeadlineExceeded: when the first check health exceeds the deadline, - # e.g. the peers take too long to be ready. - # - Cancelled: when failures in organic collectives aborts first, - # outgoing RPCs may be aborted with Cancelled. logging.error("Cluster check alive failed, aborting collectives") context.context().abort_collective_ops( errors.UNAVAILABLE, "cluster check alive failed: %s" % e) @@ -689,20 +657,32 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): time.sleep(self._check_health_interval) def _start_check_health_thread(self): - # Allocate group and instance key before starting the thread to avoid - # indeterminism. There can only be one thread that assigns group keys and - # instance keys, otherwise different workers may end up with unmatched keys - # since execution order between threads are arbitrary. - device = device_util.canonicalize(self._worker_device) - group_key = self._collective_keys.get_group_key([device]) - instance_key = self._collective_keys.get_op_instance_key() + # Use a dummy all-reduce as a barrier to wait for all workers to be up, + # otherwise the check health may fail immediately. + # + # TODO(b/151232436): change to an explicit barrier if we have it. + dummy_value = ops.convert_to_tensor([]) + logging.info("Waiting for the cluster, timeout = %d", + self._check_health_initial_timeout) + try: + self._host_cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, + dummy_value, + dummy_value, + experimental_hints=collective_util.Hints( + timeout_seconds=self._check_health_initial_timeout)) + if context.is_async(): + context.async_wait() + except errors.DeadlineExceededError: + raise RuntimeError( + "Timeout waiting for the cluster, timeout is %d seconds" % + self._check_health_initial_timeout) self._check_health_thread_should_stop = threading.Event() # Start the thread as daemon to avoid it blocking the program from exiting. # We try best to shutdown the thread but __del__ is not guaranteed to be # called when program exists. self._check_health_thread = threading.Thread( target=self._check_health, - args=(device, group_key, instance_key), daemon=True) self._check_health_thread.start() diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index ed3e2d5d951..9c554bc882c 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -81,7 +81,7 @@ def reduce_non_distributed_value( reduce_op, value, destinations, num_replicas_in_graph): """Reduce a non-DistributedValue `value` to `destinations`.""" if isinstance(value, value_lib.DistributedValues): - raise ValueError("You are passing a `DistributedValue` to " + raise ValueError("You are passing a `DistributedValues` to " "`reduce_non_distributed_value`, which is not allowed.") # If the same value is present on all replicas then the PerReplica value will @@ -216,7 +216,18 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, @tf_export("distribute.CrossDeviceOps") class CrossDeviceOps(object): - """Base class for cross-device reduction and broadcasting algorithms.""" + """Base class for cross-device reduction and broadcasting algorithms. + + The main purpose of this class is to be passed to + `tf.distribute.MirroredStrategy` in order to choose among different cross + device communication implementations. Prefer using the methods of + `tf.distribute.Strategy` instead of the ones of this class. + + Implementations: + * `tf.distribute.ReductionToOneDevice` + * `tf.distribute.NcclAllReduce` + * `tf.distribute.HierarchicalCopyAllReduce` + """ def __init__(self): pass @@ -233,24 +244,30 @@ class CrossDeviceOps(object): experimental_hints=None): """Reduce `per_replica_value` to `destinations`. - It runs the reduction operation defined by `reduce_op` and put the - result on `destinations`. + See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in + the cross-replica context. Args: - reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how - per_replica_value will be reduced. - per_replica_value: A `tf.distribute.DistributedValues` object or a tensor - with device set. - destinations: the reduction destinations. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. + reduce_op: a `tf.distribute.ReduceOp` specifying how values should be + combined. + per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` + like object. + destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a + `tf.Tensor` alike object, or a device string. It specifies the devices + to reduce to. To perform an all-reduce, pass the same to `value` and + `destinations`. Note that if it's a `tf.Variable`, the value is reduced + to the devices of that variable, and this method doesn't update the + variable. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See + `tf.distribute.experimental.CollectiveHints` for details. Returns: - a Mirrored object. + A `tf.Tensor` or `tf.distribute.DistributedValues`. Raises: - ValueError: if per_replica_value can't be converted to a PerReplica - object or if destinations aren't strings, Variables or DistributedValues + ValueError: if per_replica_value can't be converted to a + `tf.distribute.DistributedValues` or if destinations is not a string, + `tf.Variable` or `tf.distribute.DistributedValues`. """ if not isinstance(per_replica_value, value_lib.DistributedValues): per_replica_value = _make_tensor_into_per_replica(per_replica_value) @@ -274,28 +291,26 @@ class CrossDeviceOps(object): reduce_op, value_destination_pairs, experimental_hints=None): - """Reduce PerReplica objects in a batch. + """Reduce values to destinations in batches. - Reduce each first element in `value_destination_pairs` to each second - element which indicates the destinations. - - This can be faster than multiple individual `reduce`s because we can - fuse several tensors into one or multiple packs before reduction. + See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be + called in the cross-replica context. Args: - reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the - `per_replica_value` will be reduced. - value_destination_pairs: A list or a tuple of PerReplica objects (or - tensors with device set if there is one device) and destinations. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. + reduce_op: a `tf.distribute.ReduceOp` specifying how values should be + combined. + value_destination_pairs: a sequence of (value, destinations) pairs. See + `tf.distribute.CrossDeviceOps.reduce` for descriptions. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See + `tf.distribute.experimental.CollectiveHints` for details. Returns: - a list of Mirrored objects. + A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair + in `value_destination_pairs`. Raises: ValueError: if `value_destination_pairs` is not an iterable of - tuples of PerReplica objects and destinations. + tuples of `tf.distribute.DistributedValues` and destinations. """ # TODO(yuefengz): if destinations are different, split into several # `_batch_reduce` invocations. @@ -323,14 +338,20 @@ class CrossDeviceOps(object): experimental_hints) def broadcast(self, tensor, destinations): - """Broadcast the `tensor` to destinations. + """Broadcast `tensor` to `destinations`. + + This can only be called in the cross-replica context. Args: - tensor: the tensor to broadcast. - destinations: the broadcast destinations. + tensor: a `tf.Tensor` like object. The value to broadcast. + destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a + `tf.Tensor` alike object, or a device string. It specifies the devices + to broadcast to. Note that if it's a `tf.Variable`, the value is + broadcasted to the devices of that variable, this method doesn't update + the variable. Returns: - a Mirrored object. + A `tf.Tensor` or `tf.distribute.DistributedValues`. """ validate_destinations(destinations) return self.broadcast_implementation(tensor, destinations) @@ -338,27 +359,31 @@ class CrossDeviceOps(object): @doc_controls.for_subclass_implementers def reduce_implementation(self, reduce_op, per_replica_value, destinations, experimental_hints): - """The implementation of reduce of `per_replica_value` to `destinations`. + """Implementation of `reduce`. Overriding this method is useful for subclass implementers. - It runs the reduction operation defined by `reduce_op` and put the - result on `destinations`. - Args: - reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how - per_replica_value will be reduced. - per_replica_value: A PerReplica object or a tensor with device set. - destinations: the reduction destinations. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. + reduce_op: a `tf.distribute.ReduceOp` specifying how values should be + combined. + per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` + like object. + destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a + `tf.Tensor` alike object, or a device string. It specifies the devices + to reduce to. To perform an all-reduce, pass the same to `value` and + `destinations`. Note that if it's a `tf.Variable`, the value is reduced + to the devices of that variable, this method doesn't update the + variable. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See + `tf.distribute.experimental.CollectiveHints` for details. Returns: - a Mirrored object. + A `tf.Tensor` or `tf.distribute.DistributedValues`. Raises: - ValueError: if per_replica_value can't be converted to a PerReplica - object. + ValueError: if per_replica_value can't be converted to a + `tf.distribute.DistributedValues` or if destinations is not a string, + `tf.Variable` or `tf.distribute.DistributedValues`. """ raise NotImplementedError( "_reduce method must be implemented in descendants.") @@ -366,27 +391,25 @@ class CrossDeviceOps(object): @doc_controls.for_subclass_implementers def batch_reduce_implementation(self, reduce_op, value_destination_pairs, experimental_hints): - """Implementation of reduce PerReplica objects in a batch. + """Implementation of `batch_reduce`. Overriding this method is useful for subclass implementers. - Reduce each first element in `value_destination_pairs` to each second - element which indicates the destinations. - Args: - reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how - per_replica_value will be reduced. - value_destination_pairs: An iterable of tuples of PerReplica objects - (or tensors with device set if there is one device) and destinations. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + reduce_op: a `tf.distribute.ReduceOp` specifying how values should be + combined. + value_destination_pairs: a sequence of (value, destinations) pairs. See + `reduce` for descriptions. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints to perform collective operations. Returns: - a list of Mirrored objects. + A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair + in `value_destination_pairs`. Raises: ValueError: if `value_destination_pairs` is not an iterable of - tuples of PerReplica objects and destinations + tuples of `tf.distribute.DistributedValues` and destinations. """ raise NotImplementedError( "batch_reduce_implementation method must be implemented in descendants." @@ -394,26 +417,36 @@ class CrossDeviceOps(object): @doc_controls.for_subclass_implementers def broadcast_implementation(self, tensor, destinations): - """Implementation of broadcast the `tensor` to destinations. + """Implementation of `broadcast`. Args: - tensor: the tensor to broadcast. - destinations: the broadcast destinations. + tensor: a `tf.Tensor` like object. The value to broadcast. + destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a + `tf.Tensor` alike object, or a device string. It specifies the devices + to broadcast to. + `destinations`. Note that if it's a `tf.Variable`, the value is + broadcasted to the devices of that variable, this method doesn't update + the variable. Returns: - a Mirrored object. + A `tf.Tensor` or `tf.distribute.DistributedValues`. """ return simple_broadcast(tensor, destinations, always_mirrored=True) @tf_export("distribute.ReductionToOneDevice") class ReductionToOneDevice(CrossDeviceOps): - """Always do reduction to one device first and then do broadcasting. + """A CrossDeviceOps implementation that copies values to one device to reduce. - Batch reduction is done by reduction on each element one by one. + This implementation always copies values to one device to reduce them, then + broadcast reduced values to the destinations. It doesn't support efficient + batching. + + Here is how you can use `ReductionToOneDevice` in + `tf.distribute.MirroredStrategy`: ``` - mirrored_strategy = tf.distribute.MirroredStrategy( + strategy = tf.distribute.MirroredStrategy( cross_device_ops=tf.distribute.ReductionToOneDevice()) ``` """ @@ -423,8 +456,8 @@ class ReductionToOneDevice(CrossDeviceOps): Args: reduce_to_device: the intermediate device to reduce to. If None, reduce - to the first device in `destinations` of the `reduce()` method. - accumulation_fn: a function that does accumulation. If None, then + to the first device in `destinations` of the `reduce` method. + accumulation_fn: a function that does accumulation. If None, `tf.math.add_n` is used. """ self.reduce_to_device = reduce_to_device @@ -641,18 +674,24 @@ def _unpack_tensors(reduced, tensor_packer=None): class AllReduceCrossDeviceOps(CrossDeviceOps): - """Reduction using all-reduce.""" + """All-reduce implementation of CrossDeviceOps. + + It performs all-reduce when applicable using NCCL or hierarchical copy. For + the batch API, tensors will be repacked or aggregated for more efficient + cross-device transportation. + + For reduces that are not all-reduce, it falls back to + `tf.distribute.ReductionToOneDevice`. + """ def __init__(self, all_reduce_alg="nccl", num_packs=1): - """All-reduce implementation of CrossDeviceOps. - - Before performing all-reduce, tensors will be packed for more efficient - cross-device transportation. + """Initializes the object. Args: all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or "hierarchical_copy" are supported. - num_packs: If non-zero, pack values into `num_packs` splits. + num_packs: a non-negative integer. The number of packs to split values + into. If zero, no packing will be done. """ self._all_reduce_alg = all_reduce_alg self._num_packs = num_packs @@ -746,21 +785,32 @@ AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", @tf_export("distribute.NcclAllReduce") class NcclAllReduce(AllReduceCrossDeviceOps): - """Reduction using NCCL all-reduce.""" + """NCCL all-reduce implementation of CrossDeviceOps. + + It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be + repacked or aggregated for more efficient cross-device transportation. + + For reduces that are not all-reduce, it falls back to + `tf.distribute.ReductionToOneDevice`. + + Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`: + + + ``` + strategy = tf.distribute.MirroredStrategy( + cross_device_ops=tf.distribute.NcclAllReduce()) + ``` + """ def __init__(self, num_packs=1): - """NCCL all-reduce implementation of CrossDeviceOps. - - It uses Nvidia NCCL for all-reduce. Before performing all-reduce, tensors - will be repacked or aggregated for more efficient cross-device - transportation. + """Initializes the object. Args: - num_packs: values will be packed in this many splits. `num_packs` should - be greater than or equals 0. When it is zero, no packing will be done. + num_packs: a non-negative integer. The number of packs to split values + into. If zero, no packing will be done. Raises: - ValueError if `num_packs` is negative. + ValueError: if `num_packs` is negative. """ if num_packs < 0: raise ValueError( @@ -772,23 +822,34 @@ class NcclAllReduce(AllReduceCrossDeviceOps): @tf_export("distribute.HierarchicalCopyAllReduce") class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): - """Reduction using hierarchical copy all-reduce. + """Hierarchical copy all-reduce implementation of CrossDeviceOps. It reduces to one GPU along edges in some hierarchy and broadcasts back to - each GPU along the same path. Before performing all-reduce, tensors will be - repacked or aggregated for more efficient cross-device transportation. + each GPU along the same path. For the batch API, tensors will be repacked or + aggregated for more efficient cross-device transportation. This is a reduction created for Nvidia DGX-1 which assumes GPUs connects like that on DGX-1 machine. If you have different GPU inter-connections, it is likely that it would be slower than `tf.distribute.ReductionToOneDevice`. + + For reduces that are not all-reduce, it falls back to + `tf.distribute.ReductionToOneDevice`. + + Here is how you can use `HierarchicalCopyAllReduce` in + `tf.distribute.MirroredStrategy`: + + ``` + strategy = tf.distribute.MirroredStrategy( + cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) + ``` """ def __init__(self, num_packs=1): """Initializes the object. Args: - num_packs: values will be packed in this many splits. `num_packs` should - be greater than or equals 0. When it is zero, no packing will be done. + num_packs: a non-negative integer. The number of packs to split values + into. If zero, no packing will be done. Raises: ValueError if `num_packs` is negative. @@ -802,117 +863,6 @@ class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): num_packs=num_packs) -class MultiWorkerAllReduce(AllReduceCrossDeviceOps): - """All-reduce algorithms for distributed TensorFlow.""" - - def __init__(self, - worker_devices, - num_gpus_per_worker, - all_reduce_spec=("pscpu/pscpu", 2, -1), - num_packs=0): - """Initialize the all-reduce algorithm. - - Args: - worker_devices: a list of device strings for workers participating in - all-reduce. - num_gpus_per_worker: number of GPU devices per worker. - all_reduce_spec: a tuple or a named tuple or a list of tuples specifying - the all-reduce algorithm. - 1. The first element of a tuple is the name of the all-reduce algorithm. - Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", - "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with - a "/" are hierarchical, so two all-reduces are executed, the first one - aggregates tensors within a worker and the second aggregates across - workers. - 2. The second element of a tuple is the number of shards when doing - all-reduce. Let's say its values is M, each tensor after packing will be - split into M shards and then M parallel all-reduces would be performed - before finally they are concatenated backed into a complete tensor. - 3. The third element is the maximum size of tensors that will be - applicable for the algorithm specified by the first element. For - example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], - tensors with size not larger than 1024 bytes will be applied a 2-shard - "nccl" all-reduce and other tensors will be applied a 2-shard - "pscpu/pscpu" algorithm. The third elements should be in increasing - order across tuples and end with -1 which indicates infinity. - num_packs: see AllReduceCrossDeviceOps. - """ - self._worker_devices = worker_devices - self._num_gpus_per_worker = num_gpus_per_worker - super(MultiWorkerAllReduce, self).__init__(num_packs=num_packs) - - def validate_and_complete_spec(spec): - """Validate and complete the all-reduce spec.""" - # TODO(yuefengz): support namedtuple. - if not isinstance(spec, tuple): - raise ValueError( - "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) - if not spec or len(spec) > 3: - raise ValueError( - "Too many elements in the all-reduce spec tuple: %r" % spec) - if len(spec) == 1: - return AllReduceSpecTuple(spec[0], 1, -1) - elif len(spec) == 2: - return AllReduceSpecTuple(spec[0], spec[1], -1) - else: - return AllReduceSpecTuple(*spec) - - self._all_reduce_spec = [] - if isinstance(all_reduce_spec, six.string_types): - self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) - elif isinstance(all_reduce_spec, tuple): - self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) - elif isinstance(all_reduce_spec, list): - self._all_reduce_spec = [ - validate_and_complete_spec(spec) for spec in all_reduce_spec - ] - - def _batch_all_reduce(self, reduce_op, per_replica_values): - """All-reduce algorithm in a batch.""" - logging.log_first_n( - logging.INFO, "Distributed batch_all_reduce: %d all-reduces with " - "allreduce_spec = %r, num_packs = %d" % - (len(per_replica_values), self._all_reduce_spec, self._num_packs), 10) - - device_grads = _group_value_by_device(per_replica_values) - - # The all-reduce library requires fully defined shapes. - # TODO(yuefengz): when tensor sharding is not needed, static shapes are not - # required as well. - for device_grad in device_grads: - for grad, _ in device_grad: - if not grad.shape.is_fully_defined(): - raise ValueError("Shape is unknown for node %r" % grad) - - remaining_grads = device_grads - aggregated_grads = [] - for spec_tuple in self._all_reduce_spec: - if spec_tuple.limit < 0: - this_grads = remaining_grads - remaining_grads = [] - else: - (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size( - spec_tuple.limit, remaining_grads) - if this_grads: - device_grad_packs, tensor_packer = _pack_tensors( - this_grads, self._num_packs) - range_agg_grads = cross_device_utils.sum_gradients_all_reduce( - self._worker_devices, device_grad_packs, len(self._worker_devices), - spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) - range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) - - if not aggregated_grads: - aggregated_grads = range_agg_grads - else: - assert len(aggregated_grads) == len(range_agg_grads) - for i, range_agg_grad in enumerate(range_agg_grads): - aggregated_grads[i] += range_agg_grad - assert not remaining_grads - - return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0], - reduce_op) - - @tf_export("distribute.experimental.CollectiveCommunication") class CollectiveCommunication(enum.Enum): """Communication choices for CollectiveOps. diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 967de7d8426..557c601e0cd 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -433,55 +433,6 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0]) -class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, - CrossDeviceOpsTestBase): - - worker_devices = [ - "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" - ] - multi_worker_allreduce_combinations = combinations.combine( - cross_device_ops=[ - combinations.NamedObject( - "MultiWorkerAllReduce", - cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2, - ("pscpu/pscpu", 2, -1), - 0)), - combinations.NamedObject( - "MultiWorkerAllReducePack", - cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2, - ("pscpu/pscpu", 2, -1), - 1)), - combinations.NamedObject( - "MultiWorkerAllReduceMultipleSpecs", - cross_device_ops_lib.MultiWorkerAllReduce( - worker_devices, 2, [("pscpu/pscpu", 2, 100), - ("xring", 2, -1)], 0)), - ], - devices=[ - [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ], - [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:1/device:GPU:0" - ], - [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:GPU:1", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:GPU:1" - ], - ], - mode=["graph"]) - - @combinations.generate(multi_worker_allreduce_combinations) - def testReductionAndBroadcast(self, cross_device_ops, devices): - # Mimic the default device of multi-worker strategies. - with ops.device("/job:worker/replica:0/task:0"): - self._testReductionAndBroadcast(cross_device_ops, devices) - - NUM_WORKERS = 3 CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index a8d4d176ab9..bf7b41368f0 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -18,16 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections as pycoll import copy import threading -from tensorflow.python.distribute import all_reduce from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -171,65 +168,6 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, return (grad, v), None -def group_device_names(devices, group_size): - """Group device names into groups of group_size. - - Args: - devices: a list of canonical device strings. - group_size: integer which is equal to or greater than 1. - - Returns: - list of lists of devices, where each inner list is group_size long, - and each device appears at least once in an inner list. If - len(devices) % group_size == 0 then each device will appear exactly once. - - Raises: - ValueError: if group_size > len(devices) - """ - num_devices = len(devices) - if group_size > num_devices: - raise ValueError( - 'only %d devices, but group_size=%d' % (num_devices, group_size)) - num_groups = ( - num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) - groups = [[] for i in range(num_groups)] - for i in range(num_groups * group_size): - groups[i % num_groups].append(devices[i % num_devices]) - return groups - - -def split_grads_by_size(threshold_size, device_grads): - """Break gradients into two sets according to tensor size. - - Args: - threshold_size: int size cutoff for small vs large tensor. - device_grads: List of lists of (gradient, variable) tuples. The outer - list is over devices. The inner list is over individual gradients. - - Returns: - small_grads: Subset of device_grads where shape is <= threshold_size - elements. - large_grads: Subset of device_grads where shape is > threshold_size - elements. - """ - small_grads = [] - large_grads = [] - for dl in device_grads: - small_dl = [] - large_dl = [] - for (g, v) in dl: - tensor_size = g.get_shape().num_elements() - if tensor_size <= threshold_size: - small_dl.append([g, v]) - else: - large_dl.append([g, v]) - if small_dl: - small_grads.append(small_dl) - if large_dl: - large_grads.append(large_dl) - return small_grads, large_grads - - # TODO(yuefengz): use random key starts to avoid reusing keys? class CollectiveKeys(object): """Class that manages collective keys. @@ -580,272 +518,6 @@ def build_collective_gather_indexed_slices(input_slices_list, return out_slices_list -def sum_grad_and_var_all_reduce(grad_and_vars, - num_workers, - alg, - gpu_indices, - aux_devices=None, - num_shards=1): - """Apply all-reduce algorithm over specified gradient tensors.""" - with ops.name_scope('allreduce'): - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) - scaled_grads = [g for g, _ in grad_and_vars] - if alg == 'nccl': - summed_grads = nccl_ops.all_sum(scaled_grads) - elif alg == 'xring': - summed_grads = all_reduce.build_ring_all_reduce( - scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) - elif alg == 'nccl/xring': - summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, - math_ops.add) - elif alg == 'nccl/rechd': - summed_grads = all_reduce.build_nccl_then_recursive_hd( - scaled_grads, math_ops.add) - elif alg == 'nccl/pscpu': - summed_grads = all_reduce.build_nccl_then_shuffle( - scaled_grads, aux_devices, math_ops.add, math_ops.add_n) - elif alg == 'pscpu/pscpu': - second_gather_devices = aux_devices[:num_shards] - summed_grads = all_reduce.build_shuffle_then_shuffle( - scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) - elif alg in ['pscpu', 'psgpu']: - summed_grads = all_reduce.build_shuffle_all_reduce( - scaled_grads, aux_devices, math_ops.add_n) - else: - raise ValueError('unsupported all_reduce alg: ', alg) - - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result - - -def sum_gradients_all_reduce(dev_prefixes, replica_grads, num_workers, alg, - num_shards, gpu_indices): - """Apply all-reduce algorithm over specified gradient tensors. - - Args: - dev_prefixes: list of prefix strings to use to generate PS device names. - replica_grads: the gradients to reduce. - num_workers: number of worker processes across entire job. - alg: the all-reduce algorithm to apply. - num_shards: alg-specific sharding factor. - gpu_indices: indices of local GPUs in order usable for ring-reduce. - - Returns: - list of reduced tensors - """ - alg_contains_shuffle = any(n in alg for n in ['pscpu', 'psgpu']) - is_hierarchical = '/' in alg - if 'pscpu' in alg: - aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] - elif 'psgpu' in alg: - aux_devices = [ - prefix + '/gpu:%d' % i - for i in range(len(gpu_indices)) - for prefix in dev_prefixes - ] - else: - aux_devices = ['/job:localhost/cpu:0'] - # Auxiliary devices for hierarchical all-reduces. - aux_device_groups = group_device_names( - aux_devices, num_shards if alg_contains_shuffle else 1) - group_index = 0 - reduced_gv_list = [] - for grad_and_vars in zip(*replica_grads): - reduced_gv_list.append( - sum_grad_and_var_all_reduce( - grad_and_vars, num_workers, alg, gpu_indices, aux_devices - if is_hierarchical else aux_device_groups[group_index], num_shards)) - group_index = (group_index + 1) % len(aux_device_groups) - new_replica_grads = [list(x) for x in zip(*reduced_gv_list)] - return new_replica_grads - - -def extract_ranges(index_list, range_size_limit=32): - """Extract consecutive ranges and singles from index_list. - - Args: - index_list: List of monotone increasing non-negative integers. - range_size_limit: Largest size range to return. If a larger - consecutive range exists, it will be returned as multiple - ranges. - - Returns: - (ranges, singles) where ranges is a list of [first, last] pairs of - consecutive elements in index_list, and singles is all of the - other elements, in original order. - """ - if not index_list: - return [], [] - first = index_list[0] - last = first - ranges = [] - singles = [] - for i in index_list[1:]: - if i == last + 1 and (last - first) <= range_size_limit: - last = i - else: - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - first = i - last = i - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - return ranges, singles - - -GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') - - -def pack_range(key, packing, grad_vars, rng): - """Form the concatenation of a specified range of gradient tensors. - - Args: - key: Value under which to store meta-data in packing that will be used - later to restore the grad_var list structure. - packing: Dict holding data describing packed ranges of small tensors. - grad_vars: List of (grad, var) pairs for one replica. - rng: A pair of integers giving the first, last indices of a consecutive - range of tensors to be packed. - - Returns: - A tensor that is the concatenation of all the specified small tensors. - """ - to_pack = grad_vars[rng[0]:rng[1] + 1] - members = [] - variables = [] - restore_shapes = [] - with ops.name_scope('pack'): - for g, v in to_pack: - variables.append(v) - restore_shapes.append(g.shape) - with ops.device(g.device): - members.append(array_ops.reshape(g, [-1])) - packing[key] = GradPackTuple( - indices=range(rng[0], rng[1] + 1), - vars=variables, - shapes=restore_shapes) - with ops.device(members[0].device): - return array_ops.concat(members, 0) - - -def unpack_grad_tuple(gv, gpt): - """Unpack a previously packed collection of gradient tensors. - - Args: - gv: A (grad, var) pair to be unpacked. - gpt: A GradPackTuple describing the packing operation that produced gv. - - Returns: - A list of (grad, var) pairs corresponding to the values that were - originally packed into gv, maybe following subsequent operations like - reduction. - """ - elt_widths = [x.num_elements() for x in gpt.shapes] - with ops.device(gv[0].device): - with ops.name_scope('unpack'): - splits = array_ops.split(gv[0], elt_widths) - unpacked_gv = [] - for idx, s in enumerate(splits): - unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]), - gpt.vars[idx])) - return unpacked_gv - - -def pack_small_tensors(replica_grads, max_bytes=0, max_group=0): - """Concatenate small gradient tensors together for reduction. - - Args: - replica_grads: List of lists of (gradient, variable) tuples. - max_bytes: Int giving max number of bytes in a tensor that - may be considered small. - max_group: Int giving max number of small tensors that may be - concatenated into one new tensor. - - Returns: - new_replica_grads, packing where new_replica_grads is identical to - replica_grads except that all feasible small_tensors have been removed - from their places and concatenated into larger tensors that are - now in the front of the list for each replica, and packing contains - the data necessary to restore the replica_grads structure. - - Look through the first replica for gradients of the same type (float), - and small size, that are all sequential. For each such group, - replace by a new tensor that is a flattened concatenation. Note - that the corresponding variable will be absent, which doesn't matter - because it isn't used during all-reduce. - - Requires: - Every gv_list in replicas must have isomorphic structure including identical - tensor sizes and types. - """ - small_indices = [] - large_indices = [] - for idx, (g, _) in enumerate(replica_grads[0]): - if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes: - small_indices.append(idx) - else: - large_indices.append(idx) - small_ranges, small_singles = extract_ranges( - small_indices, range_size_limit=max_group) - large_indices = sorted(large_indices + small_singles) - num_gv = len(replica_grads[0]) - packing = {} - if small_ranges: - new_replica_grads = [] - for dev_idx, gv_list in enumerate(replica_grads): - assert len(gv_list) == num_gv - new_gv_list = [] - for r in small_ranges: - key = '%d:%d' % (dev_idx, len(new_gv_list)) - new_gv_list.append((pack_range(key, packing, gv_list, r), - 'packing_var_placeholder')) - for i in large_indices: - new_gv_list.append(gv_list[i]) - new_replica_grads.append(new_gv_list) - return new_replica_grads, packing - else: - return replica_grads, None - - -def unpack_small_tensors(replica_grads, packing): - """Undo the structure alterations to replica_grads done by pack_small_tensors. - - Args: - replica_grads: List of List of (grad, var) tuples. - packing: A dict generated by pack_small_tensors describing the changes - it made to replica_grads. - - Returns: - new_replica_grads: identical to replica_grads except that concatenations - of small tensors have been split apart and returned to their original - positions, paired with their original variables. - """ - if not packing: - return replica_grads - new_replica_grads = [] - num_devices = len(replica_grads) - num_packed = len(packing.keys()) // num_devices - for dev_idx, gv_list in enumerate(replica_grads): - gv_list = list(gv_list) - new_gv_list = gv_list[num_packed:] - for i in range(num_packed): - k = '%d:%d' % (dev_idx, i) - gpt = packing[k] - gv = unpack_grad_tuple(gv_list[i], gpt) - for gi, idx in enumerate(gpt.indices): - assert idx == gpt.indices[gi] - new_gv_list.insert(idx, gv[gi]) - new_replica_grads.append(new_gv_list) - return new_replica_grads - - def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" if any(isinstance(v, ops.IndexedSlices) for v in values): @@ -875,18 +547,6 @@ def copy_tensor_or_indexed_slices_to_device(value, device): return result -def contains_indexed_slices(value): - """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" - if isinstance(value, ops.IndexedSlices): - return True - elif isinstance(value, (list, tuple)) and value: - return any(contains_indexed_slices(v) for v in value) - elif isinstance(value, value_lib.DistributedValues): - return contains_indexed_slices(value.values) - else: - return False - - def is_indexed_slices(value): if isinstance(value, ops.IndexedSlices): return True diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py index 9781bf67566..626ec5cfd60 100644 --- a/tensorflow/python/distribute/cross_device_utils_test.py +++ b/tensorflow/python/distribute/cross_device_utils_test.py @@ -81,32 +81,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - self.assertTrue(cross_device_utils.contains_indexed_slices(t)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_List(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1])) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_Tuple(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1))) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerReplica(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_replica = value_lib.PerReplica((t0, t1)) - self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) + self.assertTrue(cross_device_utils.is_indexed_slices(t)) @combinations.generate(combinations.combine( mode=["graph", "eager"], diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 173caa364a9..bd24ec0145d 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -2332,16 +2332,18 @@ class StrategyExtendedV2(object): Args: - reduce_op: a `tf.distribute.ReduceOp` or string. How to reduce the value. - value: a `tf.distribute.DistributedValue`, or a `tf.Tensor` like object. - destinations: a `tf.distribute.DistributedValue`, a `tf.Variable`, a + reduce_op: a `tf.distribute.ReduceOp` value specifying how values should + be combined. Allows using string representation of the enum such as + "SUM", "MEAN". + value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. + destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a `tf.Tensor` alike object, or a device string. It specifies the devices to reduce to. To perform an all-reduce, pass the same to `value` and `destinations`. Note that if it's a `tf.Variable`, the value is reduced - to the devices of that variable, this method doesn't update the variable. - experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. See - `tf.distrbute.experimental.CollectiveHints` for details. + to the devices of that variable, and this method doesn't update the + variable. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See + `tf.distribute.experimental.CollectiveHints` for details. Returns: A tensor or value reduced to `destinations`. @@ -2413,11 +2415,13 @@ class StrategyExtendedV2(object): Args: - reduce_op: a `tf.distribute.ReduceOp`. How to reduce the value. + reduce_op: a `tf.distribute.ReduceOp` value specifying how values should + be combined. Allows using string representation of the enum such as + "SUM", "MEAN". value_destination_pairs: a sequence of (value, destinations) pairs. See - `reduce_to()` for descriptions. - experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. + `tf.distribute.Strategy.reduce_to` for descriptions. + experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See + `tf.distribute.experimental.CollectiveHints` for details. Returns: A list of reduced values, one per pair in `value_destination_pairs`. @@ -3010,32 +3014,64 @@ class ReplicaContext(object): return (device_util.current(),) def all_reduce(self, reduce_op, value, experimental_hints=None): - """All-reduces the given `value Tensor` nest across replicas. + """All-reduces `value` across all replicas. - If `all_reduce` is called in any replica, it must be called in all replicas. - The nested structure and `Tensor` shapes must be identical in all replicas. + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> def step_fn(): + ... ctx = tf.distribute.get_replica_context() + ... value = tf.identity(1.) + ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) + >>> strategy.experimental_local_results(strategy.run(step_fn)) + (, + ) - IMPORTANT: The ordering of communications must be identical in all replicas. + It supports batched operations. You can pass a list of values and it + attempts to batch them when possible. You can also specify `experimental_hints` + to indicate the desired batching behavior, e.g. batch the values into + multiple packs so that they can better overlap with computations. - Example with two replicas: - Replica 0 `value`: {'a': 1, 'b': [40, 1]} - Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> def step_fn(): + ... ctx = tf.distribute.get_replica_context() + ... value1 = tf.identity(1.) + ... value2 = tf.identity(2.) + ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) + >>> strategy.experimental_local_results(strategy.run(step_fn)) + ([PerReplica:{ + 0: , + 1: + }, PerReplica:{ + 0: , + 1: + }],) - If `reduce_op` == `SUM`: - Result (on all replicas): {'a': 4, 'b': [42, 99]} + Note that all replicas need to participate in the all-reduce, otherwise this + operation hangs. Note that if there're multiple all-reduces, they need to + execute in the same order on all replicas. Dispatching all-reduce based on + conditions is usually error-prone. - If `reduce_op` == `MEAN`: - Result (on all replicas): {'a': 2, 'b': [21, 49.5]} + This API currently can only be called in the replica context. Other + variants to reduce values across replicas are: + * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API + in the cross-replica context. + * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and + all-reduce API in the cross-replica context. + * `tf.distribute.Strategy.reduce`: a more convenient method to reduce + to the host in cross-replica context. Args: - reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. - value: The nested structure of `Tensor`s to all-reduce. The structure must - be compatible with `tf.nest`. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + reduce_op: a `tf.distribute.ReduceOp` value specifying how values should + be combined. Allows using string representation of the enum such as + "SUM", "MEAN". + value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts. + The structure and the shapes of the `tf.Tensor` need to be same on all + replicas. + experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints to perform collective operations. Returns: - A `Tensor` nest with the reduced `value`s from each replica. + A nested structure of `tf.Tensor` with the reduced values. The structure + is the same as `value`. """ if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 1c79f9552fe..d689346870e 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls def get_distributed_dataset(dataset, input_workers, strategy, - num_replicas_in_sync=None, + split_batch_by=None, input_context=None): """Returns a distributed dataset from the given tf.data.Dataset instance. @@ -77,10 +77,8 @@ def get_distributed_dataset(dataset, iterators should be created. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - num_replicas_in_sync: Optional integer. If this is not None, the value is - used to decide how to rebatch datasets into smaller batches so that - the total batch size for each step (across all workers and replicas) - adds up to `dataset`'s batch size. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -94,14 +92,14 @@ def get_distributed_dataset(dataset, dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) else: return DistributedDatasetV1( dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) @@ -501,21 +499,27 @@ class InputWorkers(object): return InputWorkers(worker_device_pairs) -def _get_next_as_optional(iterator, strategy, name=None): - """Returns an empty dataset indicator and the next input from the iterator.""" +def _get_next_as_optional(iterator, strategy, return_per_replica=False): + """Returns an empty dataset indicator and the next input from the iterator. + + Args: + iterator: a DistributedIterator object. + strategy: the `tf.distribute.Strategy` instance. + return_per_replica: a boolean. If True, the returned data will be wrapped + with `PerReplica` structure. Otherwise it is a 2D + num_input_workers*num_replicas_per_worker list. + + Returns: + A tuple (a boolean tensor indicating whether the next batch has value + globally, data from all replicas). + """ replicas = [] worker_has_values = [] worker_devices = [] for i, worker in enumerate(iterator._input_workers.worker_devices): # pylint: disable=protected-access - if name is not None: - d = tf_device.DeviceSpec.from_string(worker) - new_name = "%s_%s_%d" % (name, d.job, d.task) - else: - new_name = None - with ops.device(worker): worker_has_value, next_element = ( - iterator._iterators[i].get_next_as_list(new_name)) # pylint: disable=protected-access + iterator._iterators[i].get_next_as_list()) # pylint: disable=protected-access # Collective all-reduce requires explicit devices for inputs. with ops.device("/cpu:0"): # Converting to integers for all-reduce. @@ -525,6 +529,12 @@ def _get_next_as_optional(iterator, strategy, name=None): # Make `replicas` a flat list of values across all replicas. replicas.append(next_element) + if return_per_replica: + flattened_data = [] + for per_worker_data in replicas: + flattened_data.extend(per_worker_data) + replicas = distribute_utils.regroup(flattened_data) + # Run an all-reduce to see whether any worker has values. # TODO(b/131423105): we should be able to short-cut the all-reduce in some # cases. @@ -624,29 +634,15 @@ class DistributedIteratorBase(DistributedIteratorInterface): return self def get_next_as_optional(self): - global_has_value, replicas = _get_next_as_optional(self, self._strategy) + global_has_value, replicas = _get_next_as_optional( + self, self._strategy, return_per_replica=True) def return_none(): return optional_ops.Optional.empty(self._element_spec) - def return_value(replicas): - """Wraps the inputs for replicas in an `tf.experimental.Optional`.""" - results = [] - for i, worker in enumerate(self._input_workers.worker_devices): - with ops.device(worker): - devices = self._input_workers.compute_devices_for_worker(i) - for j, device in enumerate(devices): - with ops.device(device): - result = replicas[i][j] - results.append(result) - replicas = results - - return optional_ops.Optional.from_value( - distribute_utils.regroup(replicas)) - - return control_flow_ops.cond(global_has_value, - lambda: return_value(replicas), - lambda: return_none()) # pylint: disable=unnecessary-lambda + return control_flow_ops.cond( + global_has_value, lambda: optional_ops.Optional.from_value(replicas), + return_none) def get_next(self, name=None): """Returns the next input from the iterator for all replicas.""" @@ -673,7 +669,8 @@ class DistributedIteratorBase(DistributedIteratorInterface): out_of_range_replicas.append(data) return data - global_has_value, replicas = _get_next_as_optional(self, self._strategy) + global_has_value, replicas = _get_next_as_optional( + self, self._strategy, return_per_replica=False) results = [] for i, worker in enumerate(self._input_workers.worker_devices): with ops.device(worker): @@ -908,7 +905,8 @@ class _IterableInput(DistributedDatasetInterface): def reduce(self, initial_state, reduce_fn): """Execute a `reduce_fn` over all the elements of the input.""" iterator = iter(self) - has_data, data = _get_next_as_optional(iterator, self._strategy) + has_data, data = _get_next_as_optional( + iterator, self._strategy, return_per_replica=True) def cond(has_data, data, state): del data, state # Unused. @@ -917,16 +915,9 @@ class _IterableInput(DistributedDatasetInterface): def loop_body(has_data, data, state): """Executes `reduce_fn` in a loop till the dataset is empty.""" del has_data # Unused. - # data is list of lists here. where each list corresponds to one worker. - # TODO(b/130570614): Add support for the multiworker and TPU pods use - # case. - if self._input_workers.num_workers == 1: - data = data[0] - else: - raise ValueError("Dataset iteration within a tf.function is" - " not supported for multiple workers.") - state = reduce_fn(state, distribute_utils.regroup(data)) - has_data, data = _get_next_as_optional(iterator, self._strategy) + state = reduce_fn(state, data) + has_data, data = _get_next_as_optional( + iterator, self._strategy, return_per_replica=True) return has_data, data, state has_data, data, final_state = control_flow_ops.while_loop( @@ -941,59 +932,61 @@ class DistributedDataset(_IterableInput): dataset, input_workers, strategy, - num_replicas_in_sync=None, + split_batch_by=None, input_context=None): """Distribute the dataset on all workers. - If `num_replicas_in_sync` is not None, we split each batch of the dataset - into `num_replicas_in_sync` smaller batches, to be distributed among that - worker's replicas, so that the batch size for a global step (across all - workers and replicas) is as expected. + If `split_batch_by` is not None, we "split" each batch of the dataset by + `split_batch_by` value. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - num_replicas_in_sync: Optional integer. If this is not None, the value - is used to decide how to rebatch datasets into smaller batches so that - the total batch size for each step (across all workers and replicas) - adds up to `dataset`'s batch size. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and `num_input_pipelines` in the `InputContext`. """ super(DistributedDataset, self).__init__(input_workers=input_workers) - # We clone and shard the dataset on each worker. The current setup tries to # shard the dataset by files if possible so that each worker sees a # different subset of files. If that is not possible, will attempt to shard # the final input such that each worker will run the entire preprocessing # pipeline and only receive its own shard of the dataset. - - # Additionally, we rebatch the dataset on each worker into - # `num_replicas_in_sync` smaller batches to be distributed among that - # worker's replicas, so that the batch size for a global step (across all - # workers and replicas) adds up to the original dataset's batch size. - if num_replicas_in_sync is not None: - num_workers = input_context.num_input_pipelines if input_context else len( - input_workers.worker_devices) - rebatch_fn = self._make_rebatch_fn(dataset, num_workers, - num_replicas_in_sync) - else: - rebatch_fn = None + if split_batch_by: + try: + # pylint: disable=protected-access + with ops.colocate_with(dataset._variant_tensor): + dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by) + # Add a prefetch to pipeline rebatching for performance. + # TODO(rachelim): Instead of inserting an extra prefetch stage here, + # leverage static graph rewrites to insert _RebatchDataset before + # the final `prefetch` if it exists. + dataset = dataset.prefetch(split_batch_by) + except errors.InvalidArgumentError as e: + if "without encountering a batch" in str(e): + six.reraise( + ValueError, + ValueError( + "Call the `batch` method on the input Dataset in order to be " + "able to split your input across {} replicas.\n Please " + "the tf.distribute.Strategy guide. {}".format( + split_batch_by, e)), + sys.exc_info()[2]) + else: + raise self._cloned_datasets = [] if input_context: # Between-graph where we rely on the input_context for sharding assert input_workers.num_workers == 1 - if rebatch_fn is not None: - dataset = rebatch_fn(dataset, input_context.input_pipeline_id) dataset = input_ops.auto_shard_dataset(dataset, input_context.num_input_pipelines, - input_context.input_pipeline_id, - num_replicas_in_sync) + input_context.input_pipeline_id) self._cloned_datasets.append(dataset) else: replicated_ds = distribute.replicate(dataset, @@ -1002,71 +995,14 @@ class DistributedDataset(_IterableInput): with ops.device(worker): cloned_dataset = replicated_ds[worker] cloned_dataset = cloned_dataset.with_options(dataset.options()) - if rebatch_fn is not None: - cloned_dataset = rebatch_fn(cloned_dataset, i) cloned_dataset = input_ops.auto_shard_dataset( - cloned_dataset, len(input_workers.worker_devices), i, - num_replicas_in_sync) + cloned_dataset, len(input_workers.worker_devices), i) self._cloned_datasets.append(cloned_dataset) self._input_workers = input_workers self._strategy = strategy - self._element_spec = _create_distributed_tensor_spec( - self._strategy, self._cloned_datasets[0].element_spec) - - def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): - """Returns a callable that rebatches the input dataset. - - Args: - dataset: A `tf.data.Dataset` representing the dataset to be distributed. - num_workers: An integer representing the number of workers to distribute - `dataset` among. - num_replicas_in_sync: An integer representing the number of replicas in - sync across all workers. - """ - if num_replicas_in_sync % num_workers: - raise ValueError( - "tf.distribute expects every worker to have the same number of " - "replicas. However, encountered `num_replicas_in_sync` ({}) that " - "cannot be divided by `num_workers` ({})".format( - num_replicas_in_sync, num_workers)) - - num_replicas_per_worker = num_replicas_in_sync // num_workers - with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access - batch_size = distribute.compute_batch_size(dataset) - - def rebatch_fn(dataset, worker_index): - try: - # pylint: disable=protected-access - def apply_rebatch(): - batch_sizes = distribute.batch_sizes_for_worker( - batch_size, num_workers, num_replicas_per_worker, worker_index) - return distribute._RebatchDataset( - dataset, batch_sizes).prefetch(num_replicas_per_worker) - - def apply_legacy_rebatch(): - return distribute._LegacyRebatchDataset( - dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) - - with ops.colocate_with(dataset._variant_tensor): - return control_flow_ops.cond( - math_ops.not_equal(batch_size, -1), - true_fn=apply_rebatch, - false_fn=apply_legacy_rebatch) - except errors.InvalidArgumentError as e: - if "without encountering a batch" in str(e): - six.reraise( - ValueError, - ValueError( - "Call the `batch` method on the input Dataset in order to be " - "able to split your input across {} replicas.\n Please see " - "the tf.distribute.Strategy guide. {}".format( - num_replicas_in_sync, e)), - sys.exc_info()[2]) - else: - raise - - return rebatch_fn + self._element_spec = _create_distributed_tensor_spec(self._strategy, + dataset.element_spec) # pylint: disable=protected-access def __iter__(self): if not (context.executing_eagerly() or @@ -1111,14 +1047,14 @@ class DistributedDatasetV1(DistributedDataset): dataset, input_workers, strategy, - num_replicas_in_sync=None, + split_batch_by=None, input_context=None): self._input_workers = input_workers super(DistributedDatasetV1, self).__init__( dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) def make_one_shot_iterator(self): @@ -1367,24 +1303,20 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - num_replicas_in_sync=None, + split_batch_by=None, input_context=None): """Make an iterator for the dataset on given devices. - If `num_replicas_in_sync` is not None, we split each batch of the dataset - into `num_replicas_in_sync` smaller batches, to be distributed among that - worker's replicas, so that the batch size for a global step (across all - workers and replicas) is as expected. + If `split_batch_by` is not None, we "split" each batch of the + dataset by `split_batch_by` value. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - num_replicas_in_sync: Optional integer. If this is not None, the value is - used to decide how to rebatch datasets into smaller batches so that the - total batch size for each step (across all workers and replicas) adds up - to `dataset`'s batch size. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -1394,7 +1326,7 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) worker_iterators = _create_iterators_per_worker( dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index ea18ff77ab8..ea411542a51 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -38,7 +38,6 @@ from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import multi_worker_test_base -from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations @@ -67,7 +66,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - num_replicas_in_sync, + split_batch_by, strategy, input_context=None): # The `input_context` passed in is to shard dataset for @@ -99,7 +98,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) return iterator @@ -107,7 +106,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset, input_workers, - num_replicas_in_sync, + split_batch_by, strategy, input_context=None): if input_type == "dataset": @@ -116,14 +115,14 @@ class DistributedIteratorTestBase(test.TestCase): dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) else: return input_lib.DistributedDatasetV1( dataset, input_workers, strategy, - num_replicas_in_sync=num_replicas_in_sync, + split_batch_by=split_batch_by, input_context=input_context) else: return strategy.experimental_distribute_datasets_from_function(dataset) @@ -137,7 +136,7 @@ class DistributedIteratorTestBase(test.TestCase): expected_values, strategy, sess=None, - num_replicas_in_sync=None, + split_batch_by=None, input_context=None): if iteration_type == "for_loop" and not context.executing_eagerly(): self.skipTest("unsupported test combination.") @@ -157,7 +156,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - num_replicas_in_sync, + split_batch_by, strategy, input_context=input_context) else: @@ -166,7 +165,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset_or_input_fn, input_workers, - num_replicas_in_sync, + split_batch_by, strategy, input_context=input_context) @@ -268,9 +267,7 @@ class DistributedIteratorTestBase(test.TestCase): for i, expected_value in enumerate(expected_values): self.assertEqual(len(expected_value), len(actual_values[i])) for j in range(len(expected_value)): - self.assertAllEqual( - expected_value[j], actual_values[i][j], - "%s vs %s" % (expected_value[j], actual_values[i][j])) + self.assertAllEqual(expected_value[j], actual_values[i][j]) def _create_dataset_or_input_fn(self, input_type, input_fn): if input_type == "input_fn": @@ -364,7 +361,10 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + else: + dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -396,7 +396,10 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -429,7 +432,10 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, worker_device_pairs.setdefault(host_device, []) worker_device_pairs[host_device].append(tpu_device) worker_device_pairs = worker_device_pairs.items() - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -464,10 +470,14 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, def dataset_fn(ctx): del ctx - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - + if tf2.enabled(): + dataset1 = dataset_ops.DatasetV2.range(10) + dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2) + return dataset_ops.DatasetV2.zip((dataset1, dataset2)) + else: + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -495,7 +505,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] input_workers = input_lib.InputWorkers(worker_device_pairs) - dataset = dataset_ops.Dataset.range(10) + dataset = dataset_ops.DatasetV2.range(10) dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, distribution) @@ -518,12 +528,12 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, drop_remainder, distribution): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] - - def dataset_fn(ctx): - del ctx - return dataset_ops.Dataset.range(9).batch( + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch( # pylint: disable=g-long-lambda + 2, drop_remainder=drop_remainder) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch( # pylint: disable=g-long-lambda 2, drop_remainder=drop_remainder) - dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -548,25 +558,27 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, input_type=["dataset"], api_type=["wrap_into_iterator", "wrap_into_dataset"], iteration_type=["get_next", "for_loop"], - num_replicas_in_sync=[None, 2], + split_batch_by=[None, 2], distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.central_storage_strategy_with_gpu_and_cpu ], enable_get_next_as_optional=[True, False])) def testBatchSplitting(self, input_type, api_type, iteration_type, - num_replicas_in_sync, distribution, + split_batch_by, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] batch_size = 10 - dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) updated_batch_size = ( - batch_size // - num_replicas_in_sync if num_replicas_in_sync else batch_size) + batch_size // split_batch_by if split_batch_by else batch_size) expected_values = [[range(i, i+updated_batch_size), range(i+updated_batch_size, i+2*updated_batch_size)] for i in range(0, 100, updated_batch_size*2)] @@ -582,7 +594,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, expected_values, distribution, sess=None, - num_replicas_in_sync=num_replicas_in_sync) + split_batch_by=split_batch_by) @combinations.generate( combinations.combine( @@ -969,15 +981,18 @@ class DistributedIteratorMultiWorkerTest( auto_shard_policy): ds_option = dataset_ops.Options() ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy - dataset_fn = ( - lambda _: dataset_ops.Dataset.range(4).with_options(ds_option)) + if tf2.enabled(): + dataset_fn = ( + lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option)) + else: + dataset_fn = ( + lambda _: dataset_ops.Dataset.range(4).with_options(ds_option)) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 1)) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) worker_devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: if auto_shard_policy == AutoShardPolicy.AUTO: @@ -997,14 +1012,16 @@ class DistributedIteratorMultiWorkerTest( enable_get_next_as_optional=[True, False])) def testOneDevicePerWorker(self, input_type, api_type, iteration_type, enable_get_next_as_optional): - dataset_fn = lambda _: dataset_ops.Dataset.range(4) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(4) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(4) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 1)) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) worker_devices = self._cpu_devices() with context.graph_mode(), strategy.scope(), self.cached_session() as sess: @@ -1035,15 +1052,17 @@ class DistributedIteratorMultiWorkerTest( required_gpus=1)) def testTwoDevicesPerWorker(self, input_type, api_type, iteration_type, enable_get_next_as_optional): - dataset_fn = lambda _: dataset_ops.Dataset.range(4) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(4) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(4) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_and_one_gpu_devices()[0][1] + self._cpu_and_one_gpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 2)) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), strategy.scope(), self.cached_session() as sess: @@ -1075,16 +1094,19 @@ class DistributedIteratorMultiWorkerTest( enable_get_next_as_optional): strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 1)) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) worker_devices = self._cpu_devices() def dataset_fn(ctx): del ctx - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - + if tf2.enabled(): + dataset1 = dataset_ops.DatasetV2.range(4) + dataset2 = dataset_ops.DatasetV2.range(4).map(lambda x: x**2) + return dataset_ops.DatasetV2.zip((dataset1, dataset2)) + else: + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -1118,9 +1140,11 @@ class DistributedIteratorMultiWorkerTest( strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_and_one_gpu_devices()[0][1] + self._cpu_and_one_gpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 2)) - dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -1179,7 +1203,10 @@ class DistributedIteratorMultiWorkerTest( strategy = strategy_cls() with context.graph_mode(), strategy.scope(), self.cached_session( target="grpc://" + self._cluster_spec[task_type][task_id]) as sess: - dataset_fn = lambda _: dataset_ops.Dataset.range(5).batch(2) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(5).batch(2) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(5).batch(2) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) if (input_type == "dataset" and strategy_cls is @@ -1231,8 +1258,7 @@ class DistributedIteratorMultiWorkerTest( strategy = mirrored_strategy.MirroredStrategy( devices=(self._cpu_and_one_gpu_devices()[0][1] + self._cpu_and_one_gpu_devices()[1][1]), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:worker/task:0", "/job:worker/task:1"], 2)) + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice()) worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), strategy.scope(), self.cached_session() as sess: @@ -1249,195 +1275,5 @@ class DistributedIteratorMultiWorkerTest( strategy, sess=sess) - -# TODO(yuefengz): Refactor this into TF2 multi worker tests when those changes -# have landed. -class MultiWorkerRebatchingBehaviorTest(DistributedIteratorTestBase, - parameterized.TestCase): - - @combinations.generate( - combinations.combine( - mode=["eager"], - input_type=["dataset"], - api_type=["wrap_into_iterator", "wrap_into_dataset"], - iteration_type=["get_next", "for_loop"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - ])) - def testPartialBatchWithFileSharding(self, input_type, api_type, - iteration_type, distribution): - # Test case: 2 workers, 1 replica each. - # This test simulates the sharded behavior when we have two files each with - # 12 elements and a global batch size of 8. When we consider the dataset in - # aggregate (non-distributed), there are 24 elements divided into 3 batches - # of size 8. Hence, the correct distributed behavior is for each replica to - # see sub-batches of size 4, over three steps. - def dataset_fn(ctx): - del ctx - dataset = dataset_ops.Dataset.range(12).batch(8) - - # Set the sharding behavior to OFF for simplicity of test setup; namely, - # `dataset` defines the per-worker dataset and will not be further - # sharded. Each worker will see a dataset that is - # tf.data.Dataset.range(12).batch(8).rebatch(...). - options = dataset_ops.Options() - options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF - dataset = dataset.with_options(options) - return dataset - - dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) - - # Actual devices don't matter in this test as long as there is 1 local - # replica. - worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - - # Each test runs individually on each worker, so we compare the - # values on each worker. Each worker should rebatch its dataset into - # smaller batches of size 4. - expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]] - self._test_input_iteration( - input_type, - api_type, - iteration_type, - dataset, - worker_device_pairs, - expected_values, - distribution, - num_replicas_in_sync=distribution.num_replicas_in_sync, - input_context=distribution.extended._make_input_context()) - - @combinations.generate( - combinations.combine( - mode=["eager"], - input_type=["dataset"], - api_type=["wrap_into_iterator", "wrap_into_dataset"], - iteration_type=["get_next", "for_loop"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - ])) - def testPartialBatchWithFileShardingWithLegacyRebatch(self, input_type, - api_type, - iteration_type, - distribution): - # Test case: 2 workers, 1 replica each. - # This test simulates the sharded behavior when we have two files each with - # 12 elements and a global batch size of 8. When we consider the dataset in - # aggregate (non-distributed), there are 24 elements divided into 3 batches - # of size 8. Hence, the correct distributed behavior is for each replica to - # see sub-batches of size 4, over three steps. However, when we create a - # DistributedDataset and cannot statically infer the intended global batch - # size (e.g. if the user does not use a batching dataset), each worker will - # rebatch based on the dynamic batch size of the data encountered, even when - # it encounters partial batches. The last per-worker partial batch (size 4) - # ends up being split into two replicas, resulting in 4 steps in total, of - # (global) batch sizes 8, 8, 4, 4. - def dataset_fn(ctx): - del ctx - # The following dataset is equivalent to - # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset. - # This causes DistributedDataset to use LegacyRebatch instead. - batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4]) - offsets = dataset_ops.Dataset.from_tensor_slices([0, 8]) - dataset = dataset_ops.Dataset.zip((offsets, batch_sizes)) - - def map_fn(offset, batch_size): - return math_ops.range(offset, offset + batch_size) - - dataset = dataset.map(map_fn) - - # Set the sharding behavior to OFF for simplicity of test setup; namely, - # `dataset` defines the per-worker dataset and will not be further - # sharded. Each worker will see a dataset that is equivalent to - # tf.data.Dataset.range(12).batch(8).rebatch(...). - options = dataset_ops.Options() - options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF - dataset = dataset.with_options(options) - return dataset - - dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) - - # Actual devices don't matter in this test as long as the number of global - # replicas is 2. - worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - - # Each test runs individually on each worker, so we compare the - # values on each worker. Each worker should rebatch its dataset into - # smaller batches of size 4. - expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]] - self._test_input_iteration( - input_type, - api_type, - iteration_type, - dataset, - worker_device_pairs, - expected_values, - distribution, - num_replicas_in_sync=distribution.num_replicas_in_sync, - input_context=distribution.extended._make_input_context()) - - @combinations.generate( - combinations.combine( - mode=["eager"], - input_type=["dataset"], - api_type=["wrap_into_iterator", "wrap_into_dataset"], - iteration_type=["get_next", "for_loop"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - ], - auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA])) - def testWithDataSharding(self, input_type, api_type, iteration_type, - distribution, auto_shard_policy): - # Test case: 2 workers, 1 replica each. - # This test simulates the sharded behavior the dataset is sharded by data - # and the batch size is indivisible by the number of replicas. This checks - # that the elements are as expected and the batch size across all workers - # adds up to 3. This test will only pass if the autoshard rewrite rewrites - # RebatchDatasetV2 to legacy RebatchDataset when sharding by data. - def dataset_fn(ctx): - del ctx - dataset = dataset_ops.Dataset.range(8).batch(3) - - # Set the sharding behavior to OFF for simplicity of test setup; namely, - # `dataset` defines the per-worker dataset and will not be further - # sharded. Each worker will see a dataset that is - # tf.data.Dataset.range(12).batch(8).rebatch(...). - options = dataset_ops.Options() - options.experimental_distribute.auto_shard_policy = auto_shard_policy - dataset = dataset.with_options(options) - return dataset - - dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) - - # Actual devices don't matter in this test as long as there is 1 local - # replica. - worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - - # Each test runs individually on each worker, so we compare the - # values on each worker. We expect each worker to see different shards of - # data. - cr = distribution.cluster_resolver - worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type, - cr.task_id) - - if worker_id == 0: - expected_values = [[[0, 1]], [[3, 4]], [[6]]] - elif worker_id == 1: - expected_values = [[[2]], [[5]], [[7]]] - - self._test_input_iteration( - input_type, - api_type, - iteration_type, - dataset, - worker_device_pairs, - expected_values, - distribution, - num_replicas_in_sync=distribution.num_replicas_in_sync, - input_context=distribution.extended._make_input_context()) - - if __name__ == "__main__": - combinations.main() + test.main() diff --git a/tensorflow/python/distribute/input_ops.py b/tensorflow/python/distribute/input_ops.py index de828f4bcd9..37a7ed642d0 100644 --- a/tensorflow/python/distribute/input_ops.py +++ b/tensorflow/python/distribute/input_ops.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops # pylint: disable=protected-access -def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None): +def auto_shard_dataset(dataset, num_shards, index): """Shard the input pipeline by sharding the underlying list of files. Args: @@ -37,8 +37,6 @@ def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None): shards operating in parallel. Same usage as in `tf.data.Dataset.shard`. index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. Same usage as in `tf.data.Dataset.shard`. - num_replicas_in_sync: An integer representing the total number of replicas - across all workers. This is used in the rewrite when sharding by data. Returns: A modified `Dataset` obtained by updating the pipeline sharded by the @@ -47,14 +45,10 @@ def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None): """ if (dataset.options().experimental_distribute.auto_shard_policy != AutoShardPolicy.OFF): - if num_replicas_in_sync is None: - num_replicas_in_sync = 1 if isinstance(dataset, dataset_ops.DatasetV1): - return distribute._AutoShardDatasetV1(dataset, num_shards, index, - num_replicas_in_sync) + return distribute._AutoShardDatasetV1(dataset, num_shards, index) else: - return distribute._AutoShardDataset(dataset, num_shards, index, - num_replicas_in_sync) + return distribute._AutoShardDataset(dataset, num_shards, index) else: return dataset diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index ad13c14f218..79a563680ea 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -86,7 +86,7 @@ def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker): for task_type in ("chief", "worker"): for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): if num_gpus_per_worker == 0: - devices.append("/job:%s/task:%d" % (task_type, task_id)) + devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id)) else: devices.extend([ "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id) @@ -378,8 +378,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): self._is_multi_worker_training = True if len(workers) > 1: - if not isinstance(self._cross_device_ops, - cross_device_ops_lib.MultiWorkerAllReduce): + # Grandfather usage in the legacy tests if they're configured properly. + if (not isinstance(self._cross_device_ops, + cross_device_ops_lib.ReductionToOneDevice) or + self._cross_device_ops._num_between_graph_workers > 1): # pylint: disable=protected-access raise ValueError( "In-graph multi-worker training with `MirroredStrategy` is not " "supported.") @@ -477,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers, self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -499,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _experimental_make_numpy_dataset(self, numpy_input, session): return numpy_dataset.one_host_numpy_dataset( diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 5c86cbea1a4..acdfdbb3788 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -1148,9 +1148,9 @@ class MirroredStrategyDefunTest(test.TestCase): # pylint: disable=g-long-lambda lambda: mirrored_strategy.MirroredStrategy( devices=mirrored_strategy.all_local_devices(), - cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce([ - "/job:worker/task:0", "/job:worker/task:1" - ], context.num_gpus())), + cross_device_ops=cross_device_ops_lib.ReductionToOneDevice( + ), + ), required_gpus=1) ], mode=["graph"])) @@ -1288,9 +1288,7 @@ class MultiWorkerMirroredStrategyTestWithChief( cls._default_target = "grpc://" + cls._cluster_spec["chief"][0] def _make_cross_device_ops(self): - return cross_device_ops_lib.MultiWorkerAllReduce( - ["/job:chief/task:0", "/job:worker/task:0", "/job:worker/task:1"], - context.num_gpus()) + return cross_device_ops_lib.ReductionToOneDevice() def testMinimizeLossGraph(self): with context.graph_mode(): diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 16f4856ba16..1d4c593d48b 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -353,14 +353,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _make_dataset_iterator(self, dataset): return input_lib.DatasetIterator( dataset, self._input_workers, self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _make_input_fn_iterator( self, diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index d30bca0cebc..22aeb37ff7c 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -558,7 +558,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, input_workers, self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -615,7 +615,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, self._get_input_workers(options), self._container_strategy(), - num_replicas_in_sync=self._num_replicas_in_sync) + split_batch_by=self._num_replicas_in_sync) def _experimental_distribute_datasets_from_function(self, dataset_fn, options): diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index a19d578614e..567a34d2c23 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -769,6 +769,26 @@ class Context(object): self.ensure_initialized() pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) + def check_collective_ops_peer_health(self, task): + """Check collective peer health. + + This probes each task to see if they're still alive. Note that restarted + tasks are considered a different one, and they're considered not healthy. + + This should only be used in multi client multi worker training. + + Args: + task: a task string, must be in the format of /job:xxx/replica:0/task:N. + + Raises: + tf.errors.UnavailableError: when a peer is down. + tf.errors.FailedPreconditionError: when a peer is a different one from the + one this task has talked to, e.g. the peer has restarted. + tf.errors.InvalidArgumentError: when the task string is invalid. + """ + self.ensure_initialized() + pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task) + @property def _handle(self): if self._context_handle is None: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 45b408cd3e6..4a5b0288857 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -3073,7 +3073,11 @@ class Function(object): # Return the cached `Function` for the instance return self._descriptor_cache[instance] - def _cache_key(self, args, kwargs, include_tensor_ranks_only=False): + def _cache_key(self, + args, + kwargs, + cache_key_context, + include_tensor_ranks_only=False): """Computes the cache key given inputs and execution context.""" if self.input_signature is None: inputs = (args, kwargs) if kwargs else args @@ -3085,6 +3089,15 @@ class Function(object): assert not include_tensor_ranks_only hashable_input_signature = self._hashable_input_signature + (parent_graph, device_functions, colocation_stack, in_cross_replica_context, + variable_policy, xla_context_id) = cache_key_context + + return CacheKey(hashable_input_signature, parent_graph, device_functions, + colocation_stack, in_cross_replica_context, variable_policy, + xla_context_id) + + def _cache_key_context(self): + """Returns execution context.""" ctx = context.context() # Don't need to open an init_scope if the _cache_key call is in eager mode @@ -3153,9 +3166,8 @@ class Function(object): else: variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES - return CacheKey(hashable_input_signature, parent_graph, device_functions, - colocation_stack, in_cross_replica_context, variable_policy, - xla_context_id) + return (parent_graph, device_functions, colocation_stack, + in_cross_replica_context, variable_policy, xla_context_id) def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None): """Create a `ConcreteFunction` from `args` and `kwargs`.""" @@ -3196,7 +3208,8 @@ class Function(object): return graph_function def _define_function_with_shape_relaxation(self, args, kwargs, flat_args, - filtered_flat_args): + filtered_flat_args, + cache_key_context): """Define a function, relaxing arg shapes to avoid unnecessary retracing.""" flat_no_comp = nest.flatten((args, kwargs), expand_composites=False) @@ -3207,14 +3220,17 @@ class Function(object): # not information about the size of each dimension). if not any_composite_args: rank_only_cache_key = self._cache_key( - args, kwargs, include_tensor_ranks_only=True) + args, kwargs, cache_key_context, include_tensor_ranks_only=True) else: # For the rank-only cache key, replace any composite tensors with # shape-relaxed TypeSpecs. (cache_key_args, cache_key_kwargs) = nest.map_structure( _shape_relaxed_type_for_composite_tensor, (args, kwargs)) rank_only_cache_key = self._cache_key( - cache_key_args, cache_key_kwargs, include_tensor_ranks_only=True) + cache_key_args, + cache_key_kwargs, + cache_key_context, + include_tensor_ranks_only=True) arg_specs = [_type_spec_for(x) for x in flat_no_comp] relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get( @@ -3293,7 +3309,8 @@ class Function(object): else: flat_args, filtered_flat_args = [None], [] - cache_key = self._cache_key(args, kwargs) + cache_key_context = self._cache_key_context() + cache_key = self._cache_key(args, kwargs, cache_key_context) try: hash(cache_key) @@ -3307,37 +3324,41 @@ class Function(object): return graph_function, filtered_flat_args with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()): - logging.vlog(1, "Creating new FuncGraph for Python function %r (key: %r)", - self._python_function, cache_key) - logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]", args, - kwargs) + with trace.Trace("tf.function-graph_building"): + logging.vlog(1, + "Creating new FuncGraph for Python function %r (key: %r)", + self._python_function, cache_key) + logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]", + args, kwargs) - # pylint: disable=protected-access - call_context_key = cache_key._replace(input_signature=None) - # pylint: disable=protected-access + # pylint: disable=protected-access + call_context_key = cache_key._replace(input_signature=None) + # pylint: disable=protected-access - ag_status = ( - ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED) - with ag_ctx.ControlStatusCtx( - status=ag_status, options=self._autograph_options): + ag_status = ( + ag_ctx.Status.ENABLED + if self._autograph else ag_ctx.Status.DISABLED) + with ag_ctx.ControlStatusCtx( + status=ag_status, options=self._autograph_options): - # Build a function with shape relaxation retracing if: - # 1. shape relaxation is explicitly enabled - # and 2. there's no provided input signature - # and 3. there's been a cache miss for this calling context - if (self._experimental_relax_shapes and self.input_signature is None and - call_context_key in self._function_cache.missed): - return self._define_function_with_shape_relaxation( - args, kwargs, flat_args, filtered_flat_args) + # Build a function with shape relaxation retracing if: + # 1. shape relaxation is explicitly enabled + # and 2. there's no provided input signature + # and 3. there's been a cache miss for this calling context + if (self._experimental_relax_shapes and + self.input_signature is None and + call_context_key in self._function_cache.missed): + return self._define_function_with_shape_relaxation( + args, kwargs, flat_args, filtered_flat_args, cache_key_context) - self._function_cache.missed.add(call_context_key) - graph_function = self._create_graph_function(args, kwargs) - self._function_cache.primary[cache_key] = graph_function + self._function_cache.missed.add(call_context_key) + graph_function = self._create_graph_function(args, kwargs) + self._function_cache.primary[cache_key] = graph_function - if ops.get_default_graph()._distribution_strategy_stack: - self._traced_with_distribution_strategy = True + if ops.get_default_graph()._distribution_strategy_stack: + self._traced_with_distribution_strategy = True - return graph_function, filtered_flat_args + return graph_function, filtered_flat_args def register(func, *args, **kwargs): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 65b23401431..57e5e173a8a 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -56,6 +56,7 @@ from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec from tensorflow.python.layers import convolutional +from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import clip_ops @@ -4220,6 +4221,25 @@ class FunctionTest(test.TestCase, parameterized.TestCase): enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace self.assertEqual(trace_count[0], 4) + def testWithModuleNameScope(self): + self.skipTest('b/166158748:function does not handle this case correctly.') + + class Foo(module.Module): + + def __init__(self): + super().__init__() + self.var = None + + @def_function.function + @module.Module.with_name_scope + def bar(self, x, y): + if self.var is None: + return x + + foo = Foo() + with self.assertRaisesRegex(TypeError, 'got two values for argument'): + foo.bar(2, x=3) # pylint: disable=redundant-keyword-arg + class MultiDeviceTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 0962b9a8a70..6593c0596c1 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -24,34 +24,73 @@ from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -# No tf_export until TF is built against CUDA11 which is required for TF32. -def tensor_float_32_execution_allowed(): - """Get if TensorFloat-32 operations are enabled on supported hardware. +@tf_export('config.experimental.tensor_float_32_execution_enabled') +def tensor_float_32_execution_enabled(): + """Returns whether TensorFloat-32 is enabled. + + By default, TensorFloat-32 is enabled, but this can be changed with + `tf.config.experimental.enable_tensor_float_32_execution`. Returns: - True if TensorFloat-32 execution is enabled and False otherwise. + True if TensorFloat-32 is enabled (the default) and False otherwise """ return _pywrap_tf32_execution.is_allowed() -# No tf_export until TF is built against CUDA11 which is required for TF32. -def allow_tensor_float_32_execution(allowed): - """Allow use of TensorFloat-32 with float32 ops on supported hardware. +@tf_export('config.experimental.enable_tensor_float_32_execution') +def enable_tensor_float_32_execution(enabled): + """Enable or disable the use of TensorFloat-32 on supported hardware. - TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture. - TensorFloat-32 kernels take float32 inputs and produce float32 outputs. - Internally, the inputs are cast to a custom representation with 10-bit - mantissa (similar to float16) and 8-bit exponent (similar to float32) and are - executed using TensorCores with float32 accumulation. For more information, - see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/. + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format), + or TF32 for short, is a math mode for NVIDIA Ampere GPUs. TensorFloat-32 + execution causes certain float32 ops, such as matrix multiplications and + convolutions, to run much faster on Ampere GPUs but with reduced precision. + This reduced precision should not impact convergence of deep learning models + in practice. - TensorFloat-32 execution is disabled by default, but this may change in a - future version. + TensorFloat-32 is enabled by default in the nightly versions of TensorFlow. We + expect it will remain enabled by default in the first stable version that + TensorFloat-32 is available, which is TensorFlow 2.4, as it increases + performance and does not reduce model quality in practice. If you want to use + the full float32 precision, you can disable TensorFloat-32 execution with this + function. For example: + + ```python + x = tf.fill((2, 2), 1.0001) + y = tf.fill((2, 2), 1.) + # TensorFloat-32 is enabled, so matmul is run with reduced precision + print(tf.linalg.matmul(x, y)) # [[2., 2.], [2., 2.]] + tf.config.experimental.enable_tensor_float_32_execution(False) + # Matmul is run with full precision + print(tf.linalg.matmul(x, y)) # [[2.0002, 2.0002], [2.0002, 2.0002]] + ``` + + We soon will create an RFC proposing that TensorFloat-32 remain enabled by + default in stable versions of TensorFlow. We expect the RFC to be accepted, + but if it isn't, TensorFloat-32 will be disabled by default in TensorFlow + 2.4. + + To check whether TensorFloat-32 execution is currently enabled, use + `tf.config.experimental.tensor_float_32_execution_enabled`. + + Enabling TensorFloat-32 causes float32 inputs of supported ops, such as + `tf.linalg.matmul`, to be rounded from 23 bits of precision to 10 bits of + precision in most cases. This allows the ops to execute much faster by + utilizing the GPU's tensor cores. TensorFloat-32 has the same dynamic range as + float32, meaning it is no more likely to underflow or overflow than float32. + Ops still use float32 accumulation when TensorFloat-32 is enabled. Enabling + TensorFloat-32 only affects Ampere GPUs and subsequent GPUs that support + TensorFloat-32. + + Note TensorFloat-32 is not always used in supported ops, as only inputs of + certain shapes are supported. Support for more input shapes and more ops may + be added in the future. As a result, precision of float32 ops may decrease in + minor versions of TensorFlow. Args: - allowed: whether to allow TensorFloat-32 execution + enabled: Bool indicating whether to enable TensorFloat-32 execution. """ - _pywrap_tf32_execution.allow(allowed) + _pywrap_tf32_execution.allow(enabled) @tf_export('config.threading.get_intra_op_parallelism_threads') diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index ee7e111f6b0..d2314f10abe 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -750,18 +750,18 @@ class DeviceTest(test.TestCase): class TensorFloat32Test(test.TestCase): def setUp(self): + super(TensorFloat32Test, self).setUp() if not test_util.is_gpu_available( cuda_only=True, min_cuda_compute_capability=(8, 0)): self.skipTest('TensorFloat-32 requires an NVIDIA GPU with compute ' 'capability of at least 8.0') def tearDown(self): - config.allow_tensor_float_32_execution(False) + super(TensorFloat32Test, self).tearDown() + config.enable_tensor_float_32_execution(True) def test_tf32_enabled(self): - self.assertFalse(config.tensor_float_32_execution_allowed()) - config.allow_tensor_float_32_execution(True) - self.assertTrue(config.tensor_float_32_execution_allowed()) + self.assertTrue(config.tensor_float_32_execution_enabled()) x = array_ops.fill((8, 8), 1 + 2**-20) y = array_ops.ones((8, 8)) @@ -771,19 +771,16 @@ class TensorFloat32Test(test.TestCase): self.assertAllEqual(out, expected) def test_tf32_disabled(self): + self.assertTrue(config.tensor_float_32_execution_enabled()) + config.enable_tensor_float_32_execution(False) + self.assertFalse(config.tensor_float_32_execution_enabled()) + x = array_ops.fill((8, 8), 1 + 2**-20) y = array_ops.ones((8, 8)) out = math_ops.matmul(x, y) expected = array_ops.fill((8, 8), 8 * (1 + 2**-20)) self.assertAllEqual(out, expected) - # Test disabling tf32 after enabling it works correctly - config.allow_tensor_float_32_execution(True) - config.allow_tensor_float_32_execution(False) - self.assertFalse(config.tensor_float_32_execution_allowed()) - out = math_ops.matmul(x, y) - self.assertAllEqual(out, expected) - if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 596b93227bf..298d41a995c 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1592,6 +1592,7 @@ class UnrollLSTMTest(test.TestCase): self.assertAllClose(mv0, mv2, rtol=1e-4) self.assertAllClose(mv0, mv3, rtol=1e-4) + @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function") def testUnrollLSTMGrad(self): # Run one step of the unrolled lstm graph. def RunForwardBackward(mode, cfg=None): diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py index 45f6e254b0e..b1e1e20fc2e 100644 --- a/tensorflow/python/framework/indexed_slices.py +++ b/tensorflow/python/framework/indexed_slices.py @@ -429,9 +429,12 @@ def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False): "elements. This may consume a large amount of memory." % num_elements) else: - warnings.warn( - "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " - "This may consume a large amount of memory.") + if value.dense_shape.op.type != "VariableShape": + # VariableShape may hide static shapes behind a resource handle + # producing a warning that isn't that useful to users. + warnings.warn( + "Converting sparse IndexedSlices(%s) to a dense Tensor of unknown " + "shape. This may consume a large amount of memory." % value) return math_ops.unsorted_segment_sum( value.values, value.indices, value.dense_shape[0], name=name) diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 53d092787f6..016af65fc0a 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -63,6 +63,27 @@ def _SatisfiesTypeConstraint(dtype, attr_def, param_name): ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) +def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name): + if attr_def.has_minimum and length < attr_def.minimum: + raise ValueError("Attr '%s' of '%s' Op passed list of length %d " + "less than minimum %d." % + (param_name, op_type_name, length, attr_def.minimum)) + + +def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name): + if value not in attr_def.allowed_values.list.s: + raise ValueError( + "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % + (arg_name, op_type_name, compat.as_text(value), '", "'.join( + map(compat.as_text, attr_def.allowed_values.list.s)))) + + +def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name): + if value < attr_def.minimum: + raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." % + (arg_name, op_type_name, value, attr_def.minimum)) + + def _IsListParameter(arg): if arg.number_attr: return True @@ -172,15 +193,13 @@ def _MakeBool(v, arg_name): return v -def _MakeType(v, attr_def): +def _MakeType(v, arg_name): try: v = dtypes.as_dtype(v).base_dtype except TypeError: raise TypeError("Expected DataType for argument '%s' not %s." % - (attr_def.name, repr(v))) - i = v.as_datatype_enum - _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name) - return i + (arg_name, repr(v))) + return v.as_datatype_enum def _MakeShape(v, arg_name): @@ -670,78 +689,32 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in for attr_def in op_def.attr: key = attr_def.name value = attrs[key] - attr_value = attr_value_pb2.AttrValue() + if attr_def.HasField("default_value") and value is None: + attr_value = attr_value_pb2.AttrValue() attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue + + attr_value = value_to_attr_value(value, attr_def.type, key) if attr_def.type.startswith("list("): - if not _IsListValue(value): - raise TypeError("Expected list for attr " + key) - if attr_def.has_minimum: - if len(value) < attr_def.minimum: - raise ValueError("Attr '%s' of '%s' Op passed list of length %d " - "less than minimum %d." % - (key, op_type_name, len(value), - attr_def.minimum)) - attr_value.list.SetInParent() - if attr_def.type == "string": - attr_value.s = _MakeStr(value, key) - if attr_def.HasField("allowed_values"): - if attr_value.s not in attr_def.allowed_values.list.s: - raise ValueError( - "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, compat.as_text(attr_value.s), - '", "'.join(map(compat.as_text, - attr_def.allowed_values.list.s)))) - elif attr_def.type == "list(string)": - attr_value.list.s.extend([_MakeStr(x, key) for x in value]) - if attr_def.HasField("allowed_values"): - for x in attr_value.list.s: - if x not in attr_def.allowed_values.list.s: - raise ValueError( - "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, compat.as_text(x), - '", "'.join(map(compat.as_text, - attr_def.allowed_values.list.s)))) - elif attr_def.type == "int": - attr_value.i = _MakeInt(value, key) - if attr_def.has_minimum: - if attr_value.i < attr_def.minimum: - raise ValueError( - "Attr '%s' of '%s' Op passed %d less than minimum %d." % - (key, op_type_name, attr_value.i, attr_def.minimum)) - elif attr_def.type == "list(int)": - attr_value.list.i.extend([_MakeInt(x, key) for x in value]) - elif attr_def.type == "float": - attr_value.f = _MakeFloat(value, key) - elif attr_def.type == "list(float)": - attr_value.list.f.extend([_MakeFloat(x, key) for x in value]) - elif attr_def.type == "bool": - attr_value.b = _MakeBool(value, key) - elif attr_def.type == "list(bool)": - attr_value.list.b.extend([_MakeBool(x, key) for x in value]) - elif attr_def.type == "type": - attr_value.type = _MakeType(value, attr_def) - elif attr_def.type == "list(type)": - attr_value.list.type.extend( - [_MakeType(x, attr_def) for x in value]) - elif attr_def.type == "shape": - attr_value.shape.CopyFrom(_MakeShape(value, key)) - elif attr_def.type == "list(shape)": - attr_value.list.shape.extend( - [_MakeShape(x, key) for x in value]) - elif attr_def.type == "tensor": - attr_value.tensor.CopyFrom(_MakeTensor(value, key)) - elif attr_def.type == "list(tensor)": - attr_value.list.tensor.extend( - [_MakeTensor(x, key) for x in value]) - elif attr_def.type == "func": - attr_value.func.CopyFrom(_MakeFunc(value, key)) - elif attr_def.type == "list(func)": - attr_value.list.func.extend([_MakeFunc(x, key) for x in value]) - else: - raise TypeError("Unrecognized Attr type " + attr_def.type) + _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name) + if attr_def.HasField("allowed_values"): + if attr_def.type == "string": + _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key, + op_type_name) + elif attr_def.type == "list(string)": + for value in attr_value.list.s: + _SatisfiesAllowedStringsConstraint(value, attr_def, key, + op_type_name) + if attr_def.has_minimum and attr_def.type == "int": + _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, + op_type_name) + if attr_def.type == "type": + _SatisfiesTypeConstraint(attr_value.type, attr_def, key) + if attr_def.type == "list(type)": + for value in attr_value.list.type: + _SatisfiesTypeConstraint(value, attr_def, key) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead @@ -792,6 +765,61 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in return output_structure, op_def.is_stateful, op, outputs +def value_to_attr_value(value, attr_type, arg_name): # pylint: disable=invalid-name + """Encodes a Python value as an `AttrValue` proto message. + + Args: + value: The value to convert. + attr_type: The value type (string) -- see the AttrValue proto definition for + valid strings. + arg_name: Argument name (for error messages). + + Returns: + An AttrValue proto message that encodes `value`. + """ + attr_value = attr_value_pb2.AttrValue() + + if attr_type.startswith("list("): + if not _IsListValue(value): + raise TypeError("Expected list for attr " + arg_name) + + if attr_type == "string": + attr_value.s = _MakeStr(value, arg_name) + elif attr_type == "list(string)": + attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value]) + elif attr_type == "int": + attr_value.i = _MakeInt(value, arg_name) + elif attr_type == "list(int)": + attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value]) + elif attr_type == "float": + attr_value.f = _MakeFloat(value, arg_name) + elif attr_type == "list(float)": + attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value]) + elif attr_type == "bool": + attr_value.b = _MakeBool(value, arg_name) + elif attr_type == "list(bool)": + attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value]) + elif attr_type == "type": + attr_value.type = _MakeType(value, arg_name) + elif attr_type == "list(type)": + attr_value.list.type.extend([_MakeType(x, arg_name) for x in value]) + elif attr_type == "shape": + attr_value.shape.CopyFrom(_MakeShape(value, arg_name)) + elif attr_type == "list(shape)": + attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value]) + elif attr_type == "tensor": + attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name)) + elif attr_type == "list(tensor)": + attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value]) + elif attr_type == "func": + attr_value.func.CopyFrom(_MakeFunc(value, arg_name)) + elif attr_type == "list(func)": + attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value]) + else: + raise TypeError("Unrecognized Attr type " + attr_type) + return attr_value + + # The following symbols are used by op_def_util.cc. _pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType) _pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f07bca17061..7e51d3a330d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1506,6 +1506,13 @@ def convert_to_tensor(value, if preferred_dtype is not None: preferred_dtype = dtypes.as_dtype(preferred_dtype) + + # See below for the reason why it's `type(value)` and not just `value`. + # https://docs.python.org/3.8/reference/datamodel.html#special-lookup + overload = getattr(type(value), "__tf_tensor__", None) + if overload is not None: + return overload(value, dtype, name) + for base_type, conversion_func in tensor_conversion_registry.get(type(value)): # If dtype is None but preferred_dtype is not None, we try to # cast to preferred_dtype first. @@ -2333,6 +2340,10 @@ class Operation(object): def __repr__(self): return "" % (self.name, self.type) + def __tf_tensor__(self, dtype=None, name=None): + """Raises a helpful error.""" + raise TypeError("can't convert Operation '{}' to Tensor".format(self.name)) + @property def outputs(self): """The list of `Tensor` objects representing the outputs of this op.""" @@ -6833,13 +6844,6 @@ def get_from_proto_function(collection_name): return None -def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): - """Produce a nice error if someone converts an Operation to a Tensor.""" - raise TypeError(("Can't convert Operation '%s' to Tensor " - "(target dtype=%r, name=%r, as_ref=%r)") % - (op.name, dtype, name, as_ref)) - - def _op_to_colocate_with(v, graph): """Operation object corresponding to v to use for colocation constraints.""" if v is None: @@ -6873,10 +6877,6 @@ def _is_keras_symbolic_tensor(x): return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph" -tensor_conversion_registry.register_tensor_conversion_function( - Operation, _operation_conversion_error) - - # These symbols were originally defined in this module; import them for # backwards compatibility until all references have been updated to access # them from the indexed_slices.py module. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 4129b55e3fd..58e3f650c44 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -858,12 +858,25 @@ class OperationTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): ops.convert_to_tensor(tensor, dtype=dtypes.int32) + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorProtocol(self): + class TensorCompatible: + + def __tf_tensor__(self, dtype=None, name=None): + return constant_op.constant((1, 2, 3), dtype=dtype, name=name) + + tc = TensorCompatible() + + tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32) + self.assertEqual(tensor.dtype, dtypes.int32) + self.assertAllEqual((1, 2, 3), self.evaluate(tensor)) + @test_util.run_deprecated_v1 def testNoConvert(self): # Operation cannot be converted to Tensor. op = control_flow_ops.no_op() with self.assertRaisesRegex(TypeError, - r"Can't convert Operation '.*' to Tensor"): + "can't convert Operation '.+' to Tensor"): ops.convert_to_tensor(op) def testStr(self): diff --git a/tensorflow/python/framework/py_context_manager.cc b/tensorflow/python/framework/py_context_manager.cc new file mode 100644 index 00000000000..b895701d84f --- /dev/null +++ b/tensorflow/python/framework/py_context_manager.cc @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/framework/py_context_manager.h" + +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +bool PyContextManager::Enter(PyObject* py_context_manager) { + if (context_manager_) { + PyErr_SetString( + PyExc_ValueError, + "tensorflow::PyContextManager::Enter must be called at most once."); + } + if (!py_context_manager) return false; + context_manager_.reset(py_context_manager); + static char _enter[] = "__enter__"; + var_.reset(PyObject_CallMethod(context_manager_.get(), _enter, nullptr)); + return var_ != nullptr; +} + +PyContextManager::~PyContextManager() { + if (var_) { + static char _exit[] = "__exit__"; + static char _ooo[] = "OOO"; + if (PyErr_Occurred()) { + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + value = value ? value : Py_None; + traceback = traceback ? traceback : Py_None; + Safe_PyObjectPtr result(PyObject_CallMethod( + context_manager_.get(), _exit, _ooo, type, value, traceback)); + if (result) { + if (PyObject_IsTrue(result.get())) { + PyErr_SetString( + PyExc_ValueError, + "tensorflow::PyContextManager::Enter does not support " + "context managers that suppress exceptions."); + } else { + PyErr_Restore(type, value, traceback); + } + } + } else { + PyObject* result = PyObject_CallMethod(context_manager_.get(), _exit, + _ooo, Py_None, Py_None, Py_None); + if (result) { + Py_DECREF(result); + } else { + LOG(ERROR) + << "A context manager wrapped by tensorflow::PyContextManager " + "raised a new exception from its __new__ method. This behavior " + "is not supported by PyContextManager, and the exception is " + "being suppressed."; + PyErr_Clear(); + } + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/python/framework/py_context_manager.h b/tensorflow/python/framework/py_context_manager.h new file mode 100644 index 00000000000..6c15fccaf07 --- /dev/null +++ b/tensorflow/python/framework/py_context_manager.h @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_ + +#include + +#include + +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" + +namespace tensorflow { + +// Class that wraps a Python context manager, and calls the `__enter__` and +// `__exit__` methods at appropriate times: +// +// * When `PyContextManager::Enter(cm)` is called, the context manager `cm` +// is stored, and `cm.__enter__` is called. The result can be retrieved +// with `PyContextManager::var()`. +// * When the `PyContextManager` is destroyed, then `cm.__exit__` is called +// (with information about any active exception). +// * `PyContextManager::Enter(cm)` may be called at most once. If +// `PyContextManager::Enter()` is never called, then the destructor is a +// no-op (i.e., `__exit__` is not called). +// +// PyContextManager places two restrictons on the wrapped context managers: +// +// 1. The context manager may not suppress exceptions -- i.e., `__exit__` +// may not return a True value. If it does, then a new exception will be +// set, indicating that this is unuspported. +// 2. The context manager may not raise an exception from `__exit__` if the +// an exception is not active when it is called. If it does, then an error +// message will be logged, indicating that this is unsupported, and the +// exception will be suppressed. +// +// These restrictions are both intended to ensure that the state of +// PyErr_Occured is unchanged by PyContextManager's destructor. This is +// important, because changing the state of PyErr_Occurred in the destructor +// would mean that we are returning a nullptr with no exception set, or +// returning a non-null value with an exception set (both of which are invalid). +class PyContextManager { + public: + // Calls `py_context_manager.__enter__()`, and stores the result in `var`. + // Return true if `__enter__` succeeds, or false if `__enter__` raises an + // exception. (Also returns false if `py_context_manager` is nullptr.) + // + // Steals a reference to `py_context_manager`. (This reference is deleted + // when the destructor is called.) + bool Enter(PyObject* py_context_manager); + + // Calls `py_context_manager.__exit__()`. + ~PyContextManager(); + + // Returns the variable returned by `context_manager.__enter__()`. + // (This is the `var` bound by `with context_manager as var`.) + // Returns a borrowed reference. + PyObject* var() { return var_.get(); } + + protected: + Safe_PyObjectPtr context_manager_; + Safe_PyObjectPtr var_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_ diff --git a/tensorflow/python/framework/py_context_manager_pybind.cc b/tensorflow/python/framework/py_context_manager_pybind.cc new file mode 100644 index 00000000000..34565145444 --- /dev/null +++ b/tensorflow/python/framework/py_context_manager_pybind.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/python/framework/py_context_manager.h" + +namespace py = pybind11; + +namespace { + +// Test harness for PyContextManager. Creates a PyContextManager `cm` that +// wraps `context_manager`, calls `cm.Enter()`, and then calls `body_func` +// with `cm.var()`. Returns the result of the function. +py::handle TestPyContextManager(py::handle context_manager, + py::handle body_func) { + tensorflow::Safe_PyObjectPtr result; + { + tensorflow::PyContextManager cm; + Py_INCREF(context_manager.ptr()); // cm.Enter steals a reference. + if (!cm.Enter(context_manager.ptr())) { + throw py::error_already_set(); + } + result.reset( + PyObject_CallFunctionObjArgs(body_func.ptr(), cm.var(), nullptr)); + } + // cm gets destroyed here. + + if (result) { + return result.release(); + } else { + throw py::error_already_set(); + } +} + +} // namespace + +PYBIND11_MODULE(_py_context_manager, m) { + m.def("test_py_context_manager", TestPyContextManager); +} diff --git a/tensorflow/python/framework/py_context_manager_test.py b/tensorflow/python/framework/py_context_manager_test.py new file mode 100644 index 00000000000..60c72a806ae --- /dev/null +++ b/tensorflow/python/framework/py_context_manager_test.py @@ -0,0 +1,118 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework._py_context_manager.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import _py_context_manager +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class TestContextManager(object): + + def __init__(self, behavior="basic"): + self.log = [] + self.behavior = behavior + + def __enter__(self): + self.log.append("__enter__()") + if self.behavior == "raise_from_enter": + raise ValueError("exception in __enter__") + return "var" + + def __exit__(self, ex_type, ex_value, ex_tb): + self.log.append("__exit__(%s, %s, %s)" % (ex_type, ex_value, ex_tb)) + if self.behavior == "raise_from_exit": + raise ValueError("exception in __exit__") + if self.behavior == "suppress_exception": + return True + + +# Expected log when the body doesn't raise an exception. +NO_EXCEPTION_LOG = """\ +__enter__() +body('var') +__exit__(None, None, None)""" + +# Expected log when the body does raise an exception. (Regular expression.) +EXCEPTION_LOG = """\ +__enter__\\(\\) +body\\('var'\\) +__exit__\\(, Foo, \\)""" + + +class OpDefUtilTest(test_util.TensorFlowTestCase): + + def testBasic(self): + cm = TestContextManager() + + def body(var): + cm.log.append("body(%r)" % var) + + _py_context_manager.test_py_context_manager(cm, body) + self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG) + + def testBodyRaisesException(self): + cm = TestContextManager() + + def body(var): + cm.log.append("body(%r)" % var) + raise ValueError("Foo") + + with self.assertRaisesRegexp(ValueError, "Foo"): + _py_context_manager.test_py_context_manager(cm, body) + self.assertRegex("\n".join(cm.log), EXCEPTION_LOG) + + def testEnterRaisesException(self): + cm = TestContextManager("raise_from_enter") + + def body(var): + cm.log.append("body(%r)" % var) + + with self.assertRaisesRegexp(ValueError, "exception in __enter__"): + _py_context_manager.test_py_context_manager(cm, body) + self.assertEqual("\n".join(cm.log), "__enter__()") + + # Test behavior in unsupported case where __exit__ raises an exception. + def testExitRaisesException(self): + cm = TestContextManager("raise_from_exit") + + def body(var): + cm.log.append("body(%r)" % var) + + # Note: this does *not* raise an exception (but does log a warning): + _py_context_manager.test_py_context_manager(cm, body) + self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG) + + # Test behavior in unsupported case where __exit__ suppresses exception. + def testExitSuppressesException(self): + cm = TestContextManager("suppress_exception") + + def body(var): + cm.log.append("body(%r)" % var) + raise ValueError("Foo") + + with self.assertRaisesRegexp( + ValueError, "tensorflow::PyContextManager::Enter does not support " + "context managers that suppress exception"): + _py_context_manager.test_py_context_manager(cm, body) + self.assertRegex("\n".join(cm.log), EXCEPTION_LOG) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4d7b7746b9c..47639c99a4a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -54,6 +54,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import tape +from tensorflow.python.framework import config from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -70,6 +71,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables @@ -1908,6 +1910,75 @@ def xla_allow_fallback(description): # pylint: disable=unused-argument return xla_allow_fallback_impl +# The description is just for documentation purposes. +def run_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute test with TensorFloat-32 disabled. + + While almost every real-world deep learning model runs fine with + TensorFloat-32, many tests use assertAllClose or similar methods. + TensorFloat-32 matmuls typically will cause such methods to fail with the + default tolerances. + + Args: + description: A description used for documentation purposes, describing why + the test requires TensorFloat-32 to be disabled. + + Returns: + Decorator which runs a test with TensorFloat-32 disabled. + """ + + def decorator(f): + + @functools.wraps(f) + def decorated(self, *args, **kwargs): + allowed = config.tensor_float_32_execution_enabled() + try: + config.enable_tensor_float_32_execution(False) + f(self, *args, **kwargs) + finally: + config.enable_tensor_float_32_execution(allowed) + + return decorated + + return decorator + + +# The description is just for documentation purposes. +def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute all tests in a class with TensorFloat-32 disabled.""" + return for_all_test_methods(run_without_tensor_float_32, description) + + +def matmul_without_tf32(a, b, *args, **kwargs): + """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled. + + This effectively runs matmul without TensorFloat-32. It should only be used in + tests when verifying some other op or functions works correctly, e.g. to test + `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In + such cases, the matmul itself is not being tested so it's OK to run it with + higher precision. + + If a matmul itself is being tested, or some other op which uses matmul, use + `run_without_tensor_float_32` instead. + + Args: + a: First input to tf.linalg.matmul + b: Second input to tf.linalg.matmul + args: Other positional arguments to tf.linalg.matmul + **kwargs: Other keyword arguments to tf.linalg.matmul + + Returns: + A tensor with the same type as `a`. + """ + if config.tensor_float_32_execution_enabled() and a.dtype == "float32": + a = math_ops.cast(a, "float64") + b = math_ops.cast(b, "float64") + ret = math_ops.matmul(a, b, *args, **kwargs) + return math_ops.cast(ret, a.dtype) + else: + return math_ops.matmul(a, b, *args, **kwargs) + + class EagerSessionWarner(object): def __getattr__(self, attr): diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index d8eff0f2260..95a192af51d 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -307,6 +307,7 @@ py_library( deps = [ ":backend", ":models", + "//tensorflow/python:config", "//tensorflow/python:framework_test_lib", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 651acbfeac4..28bbe429b5f 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -840,7 +840,7 @@ def _to_tensor(x, dtype): Returns: A tensor. """ - return ops.convert_to_tensor_v2(x, dtype=dtype) + return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype) @keras_export('keras.backend.is_sparse') @@ -4766,7 +4766,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. axis: Int specifying the channels axis. `axis=-1` corresponds to data - format `channels_last', and `axis=1` corresponds to data format + format `channels_last`, and `axis=1` corresponds to data format `channels_first`. Returns: @@ -4797,8 +4797,8 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): [0. 0. 0.] """ - target = ops.convert_to_tensor_v2(target) - output = ops.convert_to_tensor_v2(output) + target = ops.convert_to_tensor_v2_with_dispatch(target) + output = ops.convert_to_tensor_v2_with_dispatch(output) target.shape.assert_is_compatible_with(output.shape) if from_logits: @@ -4838,7 +4838,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. axis: Int specifying the channels axis. `axis=-1` corresponds to data - format `channels_last', and `axis=1` corresponds to data format + format `channels_last`, and `axis=1` corresponds to data format `channels_first`. Returns: @@ -4847,8 +4847,8 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): Raises: ValueError: if `axis` is neither -1 nor one of the axes of `output`. """ - target = ops.convert_to_tensor_v2(target) - output = ops.convert_to_tensor_v2(output) + target = ops.convert_to_tensor_v2_with_dispatch(target) + output = ops.convert_to_tensor_v2_with_dispatch(output) if (not from_logits and not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and @@ -4925,8 +4925,8 @@ def binary_crossentropy(target, output, from_logits=False): Returns: A tensor. """ - target = ops.convert_to_tensor_v2(target) - output = ops.convert_to_tensor_v2(output) + target = ops.convert_to_tensor_v2_with_dispatch(target) + output = ops.convert_to_tensor_v2_with_dispatch(output) if from_logits: return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output) diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 2e0274a509b..bbbed2f524a 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -491,7 +491,7 @@ class BackendLinearAlgebraTest(test.TestCase, parameterized.TestCase): input_shape_b=(4, 7)) def test_relu(self): - x = ops.convert_to_tensor_v2([[-4, 0], [2, 7]], 'float32') + x = ops.convert_to_tensor_v2_with_dispatch([[-4, 0], [2, 7]], 'float32') # standard relu relu_op = backend.relu(x) @@ -1310,7 +1310,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase): inputs = backend.variable(input_val) initial_states = [ backend.variable(init_state_val), - ops.convert_to_tensor_v2( + ops.convert_to_tensor_v2_with_dispatch( np.concatenate([init_state_val, init_state_val], axis=-1)) ] mask = backend.variable(np_mask) @@ -1617,9 +1617,11 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): p = backend.placeholder() o = backend.categorical_crossentropy(t, p) - t_val = ops.convert_to_tensor_v2([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) - p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06], - [.05, .01, .94]]) + t_val = ops.convert_to_tensor_v2_with_dispatch([[1., 0., 0.], [0., 1., 0.], + [0., 0., 1.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05], + [.05, .89, .06], + [.05, .01, .94]]) f = backend.function([t, p], o) result = f([t_val, p_val]) @@ -1633,7 +1635,8 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): self.assertArrayNear(result, [.105, .065, .111], 1e-3) # from logits - p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.], + [2., 3., 5.]]) o = backend.categorical_crossentropy(t, p, from_logits=True) f = backend.function([t, p], o) @@ -1685,9 +1688,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): p = backend.placeholder() o = backend.sparse_categorical_crossentropy(t, p) - t_val = ops.convert_to_tensor_v2([0, 1, 2]) - p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06], - [.05, .01, .94]]) + t_val = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05], + [.05, .89, .06], + [.05, .01, .94]]) f = backend.function([t, p], o) result = f([t_val, p_val]) @@ -1703,7 +1707,8 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): _ = f([t_val, p_val]) # from logits - p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.], + [2., 3., 5.]]) o = backend.sparse_categorical_crossentropy(t, p, from_logits=True) f = backend.function([t, p], o) @@ -2124,9 +2129,10 @@ class ControlOpsTests(test.TestCase): self.assertEqual(backend.eval(tensor), [9.0]) def test_unequal_rank(self): - x = ops.convert_to_tensor_v2( + x = ops.convert_to_tensor_v2_with_dispatch( np.array([[1, 2, 3], [4, 5, 6]]), dtype='float32') - y = ops.convert_to_tensor_v2(np.array([1, 2, 3]), dtype='float32') + y = ops.convert_to_tensor_v2_with_dispatch( + np.array([1, 2, 3]), dtype='float32') def true_func(): return x diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 21215db2c6f..7b1cc291be6 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -675,6 +675,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict, contains the return value of `model.train_step`. Typically, @@ -691,6 +695,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Aggregated metric results up until this batch. @@ -708,6 +716,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict, contains the return value of `model.test_step`. Typically, @@ -725,6 +737,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Aggregated metric results up until this batch. @@ -737,6 +753,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict, contains the return value of `model.predict_step`, @@ -751,6 +771,10 @@ class Callback(object): Subclasses should override for any actions to run. + Note that if the `steps_per_execution` argument to `compile` in + `tf.keras.Model` is set to `N`, this method will only be called every `N` + batches. + Arguments: batch: Integer, index of batch within the current epoch. logs: Dict. Aggregated metric results up until this batch. @@ -896,10 +920,15 @@ class TerminateOnNaN(Callback): """Callback that terminates training when a NaN loss is encountered. """ + def __init__(self): + super(TerminateOnNaN, self).__init__() + self._supports_tf_logs = True + def on_batch_end(self, batch, logs=None): logs = logs or {} loss = logs.get('loss') if loss is not None: + loss = tf_utils.to_numpy_or_python_type(loss) if np.isnan(loss) or np.isinf(loss): print('Batch %d: Invalid loss, terminating training' % (batch)) self.model.stop_training = True @@ -1156,7 +1185,7 @@ class ModelCheckpoint(Callback): save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves the model after each epoch. When using integer, the callback saves the model at end of this many batches. If the `Model` is compiled with - `experimental_steps_per_execution=N`, then the saving criteria will be + `steps_per_execution=N`, then the saving criteria will be checked every Nth batch. Note that if the saving isn't aligned to epochs, the monitored metric may potentially be less reliable (it could reflect as little as 1 batch, since the metrics get reset every @@ -1259,16 +1288,6 @@ class ModelCheckpoint(Callback): self.save_weights_only = True def on_train_begin(self, logs=None): - # pylint: disable=protected-access - if self.model._in_multi_worker_mode: - logging.warning( - 'Automatic model reloading for interrupted job was removed from ' - 'the `ModelCheckpoint` callback in multi-worker mode, please use the ' - '`keras.callbacks.experimental.BackupAndRestore` callback instead. ' - 'See this tutorial for details: ' - 'https://www.tensorflow.org/tutorials/distribute/' - 'multi_worker_with_keras#backupandrestore_callback.' - ) if self.load_weights_on_restart: filepath_to_load = ( self._get_most_recently_modified_file_matching_pattern(self.filepath)) @@ -2417,7 +2436,7 @@ class ReduceLROnPlateau(Callback): """Resets wait counter and cooldown counter. """ if self.mode not in ['auto', 'min', 'max']: - logging.warning('Learning Rate Plateau Reducing mode %s is unknown, ' + logging.warning('Learning rate reduction mode %s is unknown, ' 'fallback to auto mode.', self.mode) self.mode = 'auto' if (self.mode == 'min' or @@ -2438,7 +2457,7 @@ class ReduceLROnPlateau(Callback): logs['lr'] = K.get_value(self.model.optimizer.lr) current = logs.get(self.monitor) if current is None: - logging.warning('Reduce LR on plateau conditioned on metric `%s` ' + logging.warning('Learning rate reduction is conditioned on metric `%s` ' 'which is not available. Available metrics are: %s', self.monitor, ','.join(list(logs.keys()))) diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 9fd8bf86609..1eaa3dd4052 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -935,7 +935,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase): verbose=0) with context.eager_mode(): - tensor = ops.convert_to_tensor(1.) + tensor = ops.convert_to_tensor_v2_with_dispatch(1.) def mock_numpy(): raise RuntimeError( @@ -975,7 +975,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase): verbose=2) with context.eager_mode(): - tensor = ops.convert_to_tensor(1.) + tensor = ops.convert_to_tensor_v2_with_dispatch(1.) def mock_numpy(): raise RuntimeError( @@ -2193,7 +2193,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): steps=100, verbose=0) - tensor = ops.convert_to_tensor(1.) + tensor = ops.convert_to_tensor_v2_with_dispatch(1.) def mock_numpy(): raise RuntimeError( diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 748ab7ce0f4..82f4dcb819c 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -245,7 +245,6 @@ distribute_py_test( main = "custom_training_loop_models_test.py", tags = [ "multi_and_single_gpu", - "no_cuda11", ], tpu_tags = [ "no_oss", # b/153615544. diff --git a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py index b9eee26220a..b014c887c03 100644 --- a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py +++ b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py @@ -56,8 +56,8 @@ class OptimizerTest(test.TestCase, parameterized.TestCase): @def_function.function def optimize(): grads = values.PerReplica([ - ops.convert_to_tensor([1., 1.]), - ops.convert_to_tensor([2., 2.]), + ops.convert_to_tensor_v2_with_dispatch([1., 1.]), + ops.convert_to_tensor_v2_with_dispatch([2., 2.]), ]) def step_fn(grads): @@ -85,7 +85,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase): @def_function.function def optimize(): - grads = ops.convert_to_tensor([1., 1.]) + grads = ops.convert_to_tensor_v2_with_dispatch([1., 1.]) def step_fn(grads): optimizer.apply_gradients( @@ -107,7 +107,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase): v = variables.Variable([0., 0.]) optimizer = gradient_descent.SGD(0.1) - grads = ops.convert_to_tensor([1., 1.]) + grads = ops.convert_to_tensor_v2_with_dispatch([1., 1.]) def step_fn(grads): with self.assertRaises(NotImplementedError): diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 4ea53429195..2329f510cb3 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -1605,6 +1605,8 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase): self.assertEqual(-1.0, v) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyWithKerasModels(test.TestCase, parameterized.TestCase): @@ -1762,7 +1764,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, outputs = keras.layers.Dense(1)(x) model = keras.Model(inputs, outputs) - model.compile('sgd', 'mse', experimental_steps_per_execution=10) + model.compile('sgd', 'mse', steps_per_execution=10) bc = BatchCountingCB() x, y = np.ones((100, 10, 10, 3)), np.ones((100, 1)) @@ -1786,7 +1788,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, outputs = keras.layers.Dense(1)(inputs) model = keras.Model(inputs, outputs) - model.compile('sgd', 'mse', experimental_steps_per_execution=20) + model.compile('sgd', 'mse', steps_per_execution=20) bc = BatchCountingCB() x, y = np.ones((100, 10)), np.ones((100, 1)) @@ -1810,7 +1812,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, outputs = keras.layers.Dense(1)(inputs) model = keras.Model(inputs, outputs) - model.compile('sgd', 'mse', experimental_steps_per_execution=20) + model.compile('sgd', 'mse', steps_per_execution=20) x, y = np.ones((100, 10)), np.ones((100, 1)) ds = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2) @@ -1846,7 +1848,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, outputs = keras.layers.Dense(1)(inputs) model = keras.Model(inputs, outputs) - model.compile('sgd', 'mse', experimental_steps_per_execution=500) + model.compile('sgd', 'mse', steps_per_execution=500) x, y = np.ones((100, 10)), np.ones((100, 1)) bc = BatchCountingCB() diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py index 6ec7cc2bac5..e04b40e33be 100644 --- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.keras import backend as K +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.platform import test @@ -47,6 +48,8 @@ def is_default_strategy(strategy): return not distribution_strategy_context.has_strategy() +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyDnnCorrectness( keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): @@ -240,6 +243,8 @@ class SubclassedModel(keras.Model): return self.dense4(x) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyDnnCorrectnessWithSubclassedModel( TestDistributionStrategyDnnCorrectness): diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py index 7e6ae3cc719..57b9b718491 100644 --- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py @@ -21,11 +21,15 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.distribute import combinations from tensorflow.python.eager import context +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.platform import test +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul. Even if Dense layers run in ' + 'float64, the test sometimes fails with tf32 enabled for unknown reasons') class DistributionStrategyCnnCorrectnessTest( keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py index aa7f0c20045..4e82b7db433 100644 --- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py @@ -69,6 +69,8 @@ class _DistributionStrategyRnnModelCorrectnessTest( return model +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class DistributionStrategyGruModelCorrectnessTest( _DistributionStrategyRnnModelCorrectnessTest): @@ -88,6 +90,8 @@ class DistributionStrategyGruModelCorrectnessTest( self.run_correctness_test(distribution, use_numpy, use_validation_data) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class DistributionStrategyLstmModelCorrectnessTest( _DistributionStrategyRnnModelCorrectnessTest): diff --git a/tensorflow/python/keras/distribute/keras_save_load_test.py b/tensorflow/python/keras/distribute/keras_save_load_test.py index 65877a0f869..fc2e2bd46ec 100644 --- a/tensorflow/python/keras/distribute/keras_save_load_test.py +++ b/tensorflow/python/keras/distribute/keras_save_load_test.py @@ -20,10 +20,13 @@ from __future__ import print_function from tensorflow.python.distribute import combinations from tensorflow.python.eager import test +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.saving import save +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class KerasSaveLoadTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py index d303a4228b5..7815d7403fd 100644 --- a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py +++ b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py @@ -26,12 +26,15 @@ from __future__ import print_function from tensorflow.python.distribute import combinations from tensorflow.python.eager import test +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.saving import save _DEFAULT_FUNCTION_KEY = 'serving_default' +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/distribute/saved_model_save_load_test.py b/tensorflow/python/keras/distribute/saved_model_save_load_test.py index 39856af2a20..2174d39bae4 100644 --- a/tensorflow/python/keras/distribute/saved_model_save_load_test.py +++ b/tensorflow/python/keras/distribute/saved_model_save_load_test.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import test from tensorflow.python.framework import tensor_spec +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import model_combinations from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.ops import array_ops @@ -32,6 +33,8 @@ from tensorflow.python.saved_model import save_options as save_options_lib from tensorflow.python.saved_model import saved_model +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class SavedModelKerasModelTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index a9c863cbc9e..dbe83bedf0a 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1006,10 +1006,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector): np_arrays.ndarray, np.ndarray, float, int)) for x in input_list): def _convert_non_tensor(x): - # Don't call `ops.convert_to_tensor_v2` on all `inputs` because + # Don't call `ops.convert_to_tensor` on all `inputs` because # `SparseTensors` can't be converted to `Tensor`. if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): - return ops.convert_to_tensor_v2(x) + return ops.convert_to_tensor_v2_with_dispatch(x) return x inputs = nest.map_structure(_convert_non_tensor, inputs) @@ -1518,7 +1518,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if loss is None: return None # Will be filtered out when computing the .losses property if not tensor_util.is_tensor(loss): - loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx()) + loss = ops.convert_to_tensor_v2_with_dispatch( + loss, dtype=backend.floatx()) loss._unconditional_loss = True # pylint: disable=protected-access return loss @@ -1535,7 +1536,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): continue if not tensor_util.is_tensor(loss) and not isinstance( loss, keras_tensor.KerasTensor): - loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx()) + loss = ops.convert_to_tensor_v2_with_dispatch( + loss, dtype=backend.floatx()) # TF Functions should take the eager path. if ((tf_utils.is_symbolic_tensor(loss) or isinstance(loss, keras_tensor.KerasTensor)) and @@ -2586,10 +2588,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # we copy them to avoid loss of KerasHistory metadata. flat_outputs = nest.flatten(outputs) flat_inputs = nest.flatten((args, kwargs)) - inputs_set = object_identity.ObjectIdentitySet(flat_inputs) + input_ids_set = {id(i) for i in flat_inputs} outputs_copy = [] for x in flat_outputs: - if x in inputs_set: + if id(x) in input_ids_set: with backend.name_scope(self.name): x = array_ops.identity(x) outputs_copy.append(x) @@ -2985,12 +2987,13 @@ class Layer(module.Module, version_utils.LayerVersionSelector): def _dedup_weights(self, weights): """Dedupe weights while maintaining order as much as possible.""" - output, seen_weights = [], object_identity.ObjectIdentitySet() + output, seen_ids = [], set() for w in weights: - if w not in seen_weights: + if id(w) not in seen_ids: output.append(w) # Track the Variable's identity to avoid __eq__ issues. - seen_weights.add(w) + seen_ids.add(id(w)) + return output def _split_out_first_arg(self, args, kwargs): @@ -3266,7 +3269,7 @@ def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): def _convert_numpy_or_python_types(x): if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): - return ops.convert_to_tensor_v2(x) + return ops.convert_to_tensor_v2_with_dispatch(x) return x diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 022718ea549..c377e3ec1bd 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -1135,7 +1135,7 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer') def test_name_scope_tf_tensor(self): - x = ops.convert_to_tensor_v2(np.ones((10, 10))) + x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10))) layer = layers.Dense( 10, activation=layers.ReLU(name='MyAct'), name='MyName3') layer(x) diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 536efb52ad1..a6141522531 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -690,10 +690,10 @@ class Layer(base_layer.Layer): # Accept NumPy and scalar inputs by converting to Tensors. if any(isinstance(x, (np.ndarray, float, int)) for x in input_list): def _convert_non_tensor(x): - # Don't call `ops.convert_to_tensor_v2` on all `inputs` because + # Don't call `ops.convert_to_tensor` on all `inputs` because # `SparseTensors` can't be converted to `Tensor`. if isinstance(x, (np.ndarray, float, int)): - return ops.convert_to_tensor_v2(x) + return ops.convert_to_tensor_v2_with_dispatch(x) return x inputs = nest.map_structure(_convert_non_tensor, inputs) input_list = nest.flatten(inputs) @@ -1053,7 +1053,8 @@ class Layer(base_layer.Layer): if loss is None: return None # Will be filtered out when computing the .losses property if not tensor_util.is_tensor(loss): - loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx()) + loss = ops.convert_to_tensor_v2_with_dispatch( + loss, dtype=backend.floatx()) loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access return loss @@ -1068,7 +1069,8 @@ class Layer(base_layer.Layer): if loss is None: continue if not tensor_util.is_tensor(loss): - loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx()) + loss = ops.convert_to_tensor_v2_with_dispatch( + loss, dtype=backend.floatx()) # TF Functions should take the eager path. if (tf_utils.is_symbolic_tensor(loss) and not base_layer_utils.is_in_tf_function()): @@ -1229,7 +1231,7 @@ class Layer(base_layer.Layer): elif hasattr(x, 'op'): update = x.op else: - update = ops.convert_to_tensor_v2(x) + update = ops.convert_to_tensor_v2_with_dispatch(x) reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update]) update._unconditional_update = update not in reachable @@ -2378,12 +2380,13 @@ class Layer(base_layer.Layer): def _dedup_weights(self, weights): """Dedupe weights while maintaining order as much as possible.""" - output, seen_weights = [], object_identity.ObjectIdentitySet() + output, seen_ids = [], set() for w in weights: - if w not in seen_weights: + if id(w) not in seen_ids: output.append(w) # Track the Variable's identity to avoid __eq__ issues. - seen_weights.add(w) + seen_ids.add(id(w)) + return output # SavedModel properties. Please see keras/saving/saved_model for details. diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index f5577bf058e..a02a329c0e8 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -149,7 +149,7 @@ class CombinerPreprocessingLayer(PreprocessingLayer): else: accumulator = self._combiner.restore(self._restore_updates()) if isinstance(data, (list, tuple)): - data = ops.convert_to_tensor_v2(data) + data = ops.convert_to_tensor_v2_with_dispatch(data) if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray, diff --git a/tensorflow/python/keras/engine/compile_utils_test.py b/tensorflow/python/keras/engine/compile_utils_test.py index 39127270539..ae92b9aeb09 100644 --- a/tensorflow/python/keras/engine/compile_utils_test.py +++ b/tensorflow/python/keras/engine/compile_utils_test.py @@ -53,7 +53,7 @@ class LossesContainerTest(keras_parameterized.TestCase): y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))] - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) @@ -86,7 +86,7 @@ class LossesContainerTest(keras_parameterized.TestCase): y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))} y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))} - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) @@ -112,7 +112,7 @@ class LossesContainerTest(keras_parameterized.TestCase): y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))] - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) @@ -135,7 +135,7 @@ class LossesContainerTest(keras_parameterized.TestCase): y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))} y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))} - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) @@ -170,7 +170,7 @@ class LossesContainerTest(keras_parameterized.TestCase): array_ops.zeros((10, 1))], 'a': array_ops.ones((10, 1)) } - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) self.assertEqual(total_loss.numpy(), 0.75) @@ -193,7 +193,7 @@ class LossesContainerTest(keras_parameterized.TestCase): y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))] - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) total_loss = loss_container(y_t, y_p, sample_weight=sw) self.assertEqual(total_loss.numpy(), 0.5) @@ -220,13 +220,13 @@ class LossesContainerTest(keras_parameterized.TestCase): }) y_p = { - 'output1': ops.convert_to_tensor([[0], [1], [2]]), - 'output2': ops.convert_to_tensor([[3], [4], [5]]), - 'output3': ops.convert_to_tensor([[6], [7], [8]]) + 'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]), + 'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]), + 'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]]) } y_t = { - 'output1': ops.convert_to_tensor([[1], [2], [3]]), - 'output3': ops.convert_to_tensor([[4], [5], [6]]) + 'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]), + 'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]]) } total_loss = loss_container(y_t, y_p) @@ -372,7 +372,7 @@ class MetricsContainerTest(keras_parameterized.TestCase): y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))] - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) self.assertLen(metric_container.metrics, 6) @@ -415,7 +415,7 @@ class MetricsContainerTest(keras_parameterized.TestCase): y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))} y_p = {'out1': array_ops.ones((10, 1)), 'out2': 2 * array_ops.ones((10, 1))} - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) mse_metric = metric_container.metrics[0] @@ -440,7 +440,7 @@ class MetricsContainerTest(keras_parameterized.TestCase): y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))] y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))] - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) self.assertLen(metric_container.metrics, 1) @@ -457,7 +457,7 @@ class MetricsContainerTest(keras_parameterized.TestCase): y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))} y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))} - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) self.assertLen(metric_container.metrics, 1) @@ -487,7 +487,7 @@ class MetricsContainerTest(keras_parameterized.TestCase): array_ops.zeros((10, 1))], 'a': array_ops.ones((10, 1)) } - sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) metric_container.update_state(y_t, y_p, sample_weight=sw) self.assertLen(metric_container.metrics, 3) @@ -548,9 +548,9 @@ class MetricsContainerTest(keras_parameterized.TestCase): metric_container = compile_utils.MetricsContainer( metrics=['mae'], weighted_metrics=['mae']) - y_t = ops.convert_to_tensor_v2([[0], [3], [0]]) - y_p = ops.convert_to_tensor_v2([[0], [0], [0]]) - sw = ops.convert_to_tensor_v2([[1], [0], [1]]) + y_t = ops.convert_to_tensor_v2_with_dispatch([[0], [3], [0]]) + y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [0], [0]]) + sw = ops.convert_to_tensor_v2_with_dispatch([[1], [0], [1]]) metric_container.update_state(y_t, y_p, sample_weight=sw) self.assertLen(metric_container.metrics, 2) @@ -566,8 +566,8 @@ class MetricsContainerTest(keras_parameterized.TestCase): def test_broadcast_metrics_to_dict(self): metric_container = compile_utils.MetricsContainer(metrics=['mae']) - y_p = {'output': ops.convert_to_tensor([[0], [1], [2]])} - y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])} + y_p = {'output': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])} + y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])} metric_container.update_state(y_t, y_p) mae_metric = metric_container.metrics[0] @@ -578,8 +578,8 @@ class MetricsContainerTest(keras_parameterized.TestCase): metric_container = compile_utils.MetricsContainer( metrics=['mae'], output_names=['output']) - y_p = ops.convert_to_tensor([[0], [1], [2]]) - y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])} + y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]) + y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])} metric_container.update_state(y_t, y_p) mae_metric = metric_container.metrics[0] @@ -595,13 +595,13 @@ class MetricsContainerTest(keras_parameterized.TestCase): }) y_p = { - 'output1': ops.convert_to_tensor([[0], [1], [2]]), - 'output2': ops.convert_to_tensor([[3], [4], [5]]), - 'output3': ops.convert_to_tensor([[6], [7], [8]]) + 'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]), + 'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]), + 'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]]) } y_t = { - 'output1': ops.convert_to_tensor([[1], [2], [3]]), - 'output3': ops.convert_to_tensor([[4], [5], [6]]) + 'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]), + 'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]]) } metric_container.update_state(y_t, y_p) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index e9662da73e7..0df15f368fa 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1006,7 +1006,7 @@ def _process_tensorlike(inputs): dtype = None if issubclass(x.dtype.type, np.floating): dtype = backend.floatx() - return ops.convert_to_tensor(x, dtype=dtype) + return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype) elif scipy_sparse and scipy_sparse.issparse(x): return _scipy_sparse_to_sparse_tensor(x) return x @@ -1281,7 +1281,7 @@ def _make_class_weight_map_fn(class_weight): "than the number of classes, found {}").format(class_weight) raise ValueError(error_msg) - class_weight_tensor = ops.convert_to_tensor_v2( + class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch( [class_weight[int(c)] for c in class_ids]) def _class_weights_map_fn(*data): diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index fad193009cf..b17410c6d25 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -446,7 +446,7 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase): def test_training(self): # First verify that DummyArrayLike can't be converted to a Tensor with self.assertRaises(TypeError): - ops.convert_to_tensor_v2(self.arraylike_input) + ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input) # Then train on the array like. # It should not be converted to a tensor directly (which would force it into @@ -914,7 +914,7 @@ class DataHandlerTest(keras_parameterized.TestCase): def generator(): for _ in range(2): for step in range(3): - yield (ops.convert_to_tensor_v2([step]),) + yield (ops.convert_to_tensor_v2_with_dispatch([step]),) data_handler = data_adapter.DataHandler( generator(), epochs=2, steps_per_epoch=3) @@ -1007,20 +1007,20 @@ class TestValidationSplit(keras_parameterized.TestCase): y = np.array([0, 2, 4, 6, 8]) sw = np.array([0, 4, 8, 12, 16]) else: - x = ops.convert_to_tensor_v2([0, 1, 2, 3, 4]) - y = ops.convert_to_tensor_v2([0, 2, 4, 6, 8]) - sw = ops.convert_to_tensor_v2([0, 4, 8, 12, 16]) + x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4]) + y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8]) + sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16]) (train_x, train_y, train_sw), (val_x, val_y, val_sw) = ( data_adapter.train_validation_split((x, y, sw), validation_split=0.2)) if use_numpy: - train_x = ops.convert_to_tensor_v2(train_x) - train_y = ops.convert_to_tensor_v2(train_y) - train_sw = ops.convert_to_tensor_v2(train_sw) - val_x = ops.convert_to_tensor_v2(val_x) - val_y = ops.convert_to_tensor_v2(val_y) - val_sw = ops.convert_to_tensor_v2(val_sw) + train_x = ops.convert_to_tensor_v2_with_dispatch(train_x) + train_y = ops.convert_to_tensor_v2_with_dispatch(train_y) + train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw) + val_x = ops.convert_to_tensor_v2_with_dispatch(val_x) + val_y = ops.convert_to_tensor_v2_with_dispatch(val_y) + val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw) self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3]) self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6]) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 9cb35ff1e88..52d73ada157 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -445,6 +445,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): loss_weights=None, weighted_metrics=None, run_eagerly=None, + steps_per_execution=None, **kwargs): """Configures the model for training. @@ -496,17 +497,18 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): logic will not be wrapped in a `tf.function`. Recommended to leave this as `None` unless your `Model` cannot be run inside a `tf.function`. - **kwargs: Any additional arguments. Supported arguments: - - `experimental_steps_per_execution`: Int. The number of batches to - run during each `tf.function` call. Running multiple batches - inside a single `tf.function` call can greatly improve performance - on TPUs or small models with a large Python overhead. Note that if - this value is set to `N`, `Callback.on_batch` methods will only be - called every `N` batches. This currently defaults to `1`. At most, - one full epoch will be run each execution. If a number larger than - the size of the epoch is passed, the execution will be truncated - to the size of the epoch. - - `sample_weight_mode` for backward compatibility. + steps_per_execution: Int. Defaults to 1. The number of batches to + run during each `tf.function` call. Running multiple batches + inside a single `tf.function` call can greatly improve performance + on TPUs or small models with a large Python overhead. + At most, one full epoch will be run each + execution. If a number larger than the size of the epoch is passed, + the execution will be truncated to the size of the epoch. + Note that if `steps_per_execution` is set to `N`, + `Callback.on_batch_begin` and `Callback.on_batch_end` methods + will only be called every `N` batches + (i.e. before/after each `tf.function` execution). + **kwargs: Arguments supported for backwards compatibility only. Raises: ValueError: In case of invalid arguments for @@ -514,6 +516,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): """ base_layer.keras_api_gauge.get_cell('compile').set(True) with self.distribute_strategy.scope(): + if 'experimental_steps_per_execution' in kwargs: + logging.warn('The argument `steps_per_execution` is no longer ' + 'experimental. Pass `steps_per_execution` instead of ' + '`experimental_steps_per_execution`.') + if not steps_per_execution: + steps_per_execution = kwargs.pop('experimental_steps_per_execution') + self._validate_compile(optimizer, metrics, **kwargs) self._run_eagerly = run_eagerly @@ -523,9 +532,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): self.compiled_metrics = compile_utils.MetricsContainer( metrics, weighted_metrics, output_names=self.output_names) - experimental_steps_per_execution = kwargs.pop( - 'experimental_steps_per_execution', 1) - self._configure_steps_per_execution(experimental_steps_per_execution) + self._configure_steps_per_execution(steps_per_execution or 1) # Initializes attrs that are reset each time `compile` is called. self._reset_compile_cache() @@ -2460,9 +2467,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): if kwargs.pop('target_tensors', None) is not None: raise ValueError( 'target_tensors argument is not supported when executing eagerly.') - invalid_kwargs = set(kwargs) - { - 'experimental_steps_per_execution', 'sample_weight_mode' - } + invalid_kwargs = set(kwargs) - {'sample_weight_mode'} if invalid_kwargs: raise TypeError('Invalid keyword argument(s) in `compile`: %s' % (invalid_kwargs,)) diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index b3ce3d13ed7..09e6f0d1edd 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -121,7 +121,7 @@ def _model_loss(model, if any( isinstance(input_t, (np.ndarray, float, int)) for input_t in nest.flatten(inputs)): - inputs = nest.map_structure(ops.convert_to_tensor_v2, inputs) + inputs = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, inputs) outs = model(inputs, **kwargs) outs = nest.flatten(outs) @@ -131,7 +131,8 @@ def _model_loss(model, # TODO(sallymatson/psv): check if we should do same mismatch fix for weights if sample_weights: sample_weights = [ - training_utils.cast_if_floating_dtype(ops.convert_to_tensor_v2(val)) + training_utils.cast_if_floating_dtype( + ops.convert_to_tensor_v2_with_dispatch(val)) if val is not None else None for val in sample_weights ] diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 84bcd99922f..7a8c1c16eaa 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -1009,7 +1009,7 @@ def standardize_weights(y, class_sample_weight = math_ops.cast(class_sample_weight, K.floatx()) if sample_weight is not None: sample_weight = math_ops.cast( - ops.convert_to_tensor_v2(sample_weight), K.floatx()) + ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx()) else: y_classes = y if len(y.shape) == 2: @@ -1365,7 +1365,7 @@ def check_steps_argument(input_data, steps, steps_name): def cast_single_tensor(x, dtype=None): if isinstance(x, np.ndarray): - x = ops.convert_to_tensor_v2(x) + x = ops.convert_to_tensor_v2_with_dispatch(x) dtype = dtype or K.floatx() if x.dtype.is_floating: return math_ops.cast(x, dtype=dtype) @@ -1391,7 +1391,7 @@ def cast_if_floating_dtype_and_mismatch(targets, outputs): new_targets = [] for target, out in zip(targets, outputs): if isinstance(target, np.ndarray): - target = ops.convert_to_tensor_v2(target) + target = ops.convert_to_tensor_v2_with_dispatch(target) if target.dtype != out.dtype: new_targets.append(cast_single_tensor(target, dtype=out.dtype)) else: diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 6458d097f62..b5b876b79bf 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -814,7 +814,6 @@ cuda_py_test( python_version = "PY3", shard_count = 12, tags = [ - "no_cuda11", "no_oss", ], xla_enable_strict_auto_jit = False, diff --git a/tensorflow/python/keras/layers/convolutional_transpose_test.py b/tensorflow/python/keras/layers/convolutional_transpose_test.py index dd73d22d51b..4326044458e 100644 --- a/tensorflow/python/keras/layers/convolutional_transpose_test.py +++ b/tensorflow/python/keras/layers/convolutional_transpose_test.py @@ -207,3 +207,6 @@ class Conv3DTransposeTest(keras_parameterized.TestCase): }, input_shape=(None, 3, None, None, None), input_data=input_data) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 36ac087ef64..1ceedad4791 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -201,7 +201,7 @@ class Dropout(Layer): noise_shape = [] for i, value in enumerate(self.noise_shape): noise_shape.append(concrete_inputs_shape[i] if value is None else value) - return ops.convert_to_tensor_v2(noise_shape) + return ops.convert_to_tensor_v2_with_dispatch(noise_shape) def call(self, inputs, training=None): if training is None: diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index f6509814249..b7a11d32c71 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -504,14 +504,14 @@ class CoreLayersTest(keras_parameterized.TestCase): keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2)) def test_dense_dtype(self): - inputs = ops.convert_to_tensor_v2( + inputs = ops.convert_to_tensor_v2_with_dispatch( np.random.randint(low=0, high=7, size=(2, 2))) layer = keras.layers.Dense(5, dtype='float32') outputs = layer(inputs) self.assertEqual(outputs.dtype, 'float32') def test_dense_with_policy(self): - inputs = ops.convert_to_tensor_v2( + inputs = ops.convert_to_tensor_v2_with_dispatch( np.random.randint(low=0, high=7, size=(2, 2))) layer = keras.layers.Dense(5, dtype=policy.Policy('mixed_float16')) outputs = layer(inputs) diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py index cd277a1a6a9..ab2912505ef 100644 --- a/tensorflow/python/keras/layers/dense_attention.py +++ b/tensorflow/python/keras/layers/dense_attention.py @@ -180,7 +180,7 @@ class BaseDenseAttention(Layer): q_mask = mask[0] if q_mask is None: return None - return ops.convert_to_tensor_v2(q_mask) + return ops.convert_to_tensor_v2_with_dispatch(q_mask) return None def _validate_call_args(self, inputs, mask): diff --git a/tensorflow/python/keras/layers/kernelized.py b/tensorflow/python/keras/layers/kernelized.py index eac985e63bf..c8a6a65d68c 100644 --- a/tensorflow/python/keras/layers/kernelized.py +++ b/tensorflow/python/keras/layers/kernelized.py @@ -218,7 +218,7 @@ class RandomFourierFeatures(base_layer.Layer): super(RandomFourierFeatures, self).build(input_shape) def call(self, inputs): - inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype) + inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, dtype=self.dtype) inputs = gen_math_ops.cast(inputs, dtypes.float32) kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel outputs = gen_math_ops.mat_mul(inputs, kernel) diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py index 1e33edd497c..416a26dbb59 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -282,7 +282,7 @@ class RNNCell(base_layer.Layer): def get_initial_state(self, inputs=None, batch_size=None, dtype=None): if inputs is not None: # Validate the given batch_size and dtype against inputs if provided. - inputs = ops.convert_to_tensor(inputs, name="inputs") + inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, name="inputs") if batch_size is not None: if tensor_util.is_tensor(batch_size): static_batch_size = tensor_util.constant_value( diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py index 2e3923918a0..9618bc75545 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py @@ -116,7 +116,7 @@ class DropoutWrapperBase(object): with ops.name_scope_v2("DropoutWrapperInit"): def tensor_and_const_value(v): - tensor_value = ops.convert_to_tensor(v) + tensor_value = ops.convert_to_tensor_v2_with_dispatch(v) const_value = tensor_util.constant_value(tensor_value) return (tensor_value, const_value) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index d9bac2c2e92..92178bc8fb1 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -41,26 +41,43 @@ from tensorflow.python.util.tf_export import keras_export class BatchNormalizationBase(Layer): - r"""Normalize and scale inputs or activations. + r"""Layer that normalizes its inputs. - Normalize the activations of the previous layer at each batch, - i.e. applies a transformation that maintains the mean activation - close to 0 and the activation standard deviation close to 1. + Batch normalization applies a transformation that maintains the mean output + close to 0 and the output standard deviation close to 1. - Batch normalization differs from other layers in several key aspects: + Importantly, batch normalization works differently during training and + during inference. - 1) Adding BatchNormalization with `training=True` to a model causes the - result of one example to depend on the contents of all other examples in a - minibatch. Be careful when padding batches or masking examples, as these can - change the minibatch statistics and affect other examples. + **During training** (i.e. when using `fit()` or when calling the layer/model + with the argument `training=True`), the layer normalizes its output using + the mean and standard deviation of the current batch of inputs. That is to + say, for each channel being normalized, the layer returns + `(batch - mean(batch)) / (var(batch) + epsilon) * gamma + beta`, where: - 2) Updates to the weights (moving statistics) are based on the forward pass - of a model rather than the result of gradient computations. + - `epsilon` is small constant (configurable as part of the constructor + arguments) + - `gamma` is a learned scaling factor (initialized as 1), which + can be disabled by passing `scale=False` to the constructor. + - `beta` is a learned offset factor (initialized as 0), which + can be disabled by passing `center=False` to the constructor. - 3) When performing inference using a model containing batch normalization, it - is generally (though not always) desirable to use accumulated statistics - rather than mini-batch statistics. This is accomplished by passing - `training=False` when calling the model, or using `model.predict`. + **During inference** (i.e. when using `evaluate()` or `predict()` or when + calling the layer/model with the argument `training=False` (which is the + default), the layer normalizes its output using a moving average of the + mean and standard deviation of the batches it has seen during training. That + is to say, it returns + `(batch - self.moving_mean) / (self.moving_var + epsilon) * gamma + beta`. + + `self.moving_mean` and `self.moving_var` are non-trainable variables that + are updated each time the layer in called in training mode, as such: + + - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` + - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)` + + As such, the layer will only normalize its inputs during inference + *after having been trained on data that has similar statistics as the + inference data*. Arguments: axis: Integer, the axis that should be normalized (typically the features @@ -117,6 +134,7 @@ class BatchNormalizationBase(Layer): across all examples), and finally apply gamma and/or beta. If `None`, no adjustment is applied. Cannot be specified if virtual_batch_size is specified. + Call arguments: inputs: Input tensor (of any rank). training: Python boolean indicating whether the layer should behave in @@ -125,21 +143,13 @@ class BatchNormalizationBase(Layer): variance of the current batch of inputs. - `training=False`: The layer will normalize its inputs using the mean and variance of its moving statistics, learned during training. + Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. + Output shape: Same shape as input. {{TRAINABLE_ATTRIBUTE_NOTE}} - Normalization equations: Consider the intermediate activations \(x\) of a - mini-batch of size - \\(m\\): We can compute the mean and variance of the batch \\({\mu_B} = - \frac{1}{m} \sum_{i=1}^{m} {x_i}\\) \\({\sigma_B^2} = \frac{1}{m} - \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\) and then compute a normalized - \\(x\\), including a small factor \\({\epsilon}\\) for numerical - stability. \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + - \epsilon}}\\) And finally \\(\hat{x}\) is linearly transformed by - \({\gamma}\\) - and \\({\beta}\\), which are learned parameters: \\({y_i} = {\gamma * - \hat{x_i} + \beta}\\) + Reference: - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). """ @@ -480,7 +490,8 @@ class BatchNormalizationBase(Layer): def _assign_moving_average(self, variable, value, momentum, inputs_size): with K.name_scope('AssignMovingAvg') as scope: with ops.colocate_with(variable): - decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay') + decay = ops.convert_to_tensor_v2_with_dispatch( + 1.0 - momentum, name='decay') if decay.dtype != variable.dtype.base_dtype: decay = math_ops.cast(decay, variable.dtype.base_dtype) update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay @@ -585,7 +596,7 @@ class BatchNormalizationBase(Layer): lambda: self.momentum, lambda: 1.0) else: - momentum = ops.convert_to_tensor_v2(self.momentum) + momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum) def mean_update(): """Update self.moving_mean with the most recent data point.""" @@ -787,10 +798,11 @@ class BatchNormalizationBase(Layer): moving_variance = self.moving_variance mean = control_flow_util.smart_cond( - training, lambda: mean, lambda: ops.convert_to_tensor_v2(moving_mean)) + training, lambda: mean, + lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean)) variance = control_flow_util.smart_cond( training, lambda: variance, - lambda: ops.convert_to_tensor_v2(moving_variance)) + lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance)) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing.py b/tensorflow/python/keras/layers/preprocessing/category_crossing.py index bdb29d21c4e..747a105afdd 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_crossing.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing.py @@ -143,7 +143,7 @@ class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer): def _preprocess_input(self, inp): if isinstance(inp, (list, tuple, np.ndarray)): - inp = ops.convert_to_tensor(inp) + inp = ops.convert_to_tensor_v2_with_dispatch(inp) if inp.shape.rank == 1: inp = array_ops.expand_dims(inp, axis=-1) return inp diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py index 95540176e04..87112fa3d04 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py @@ -269,7 +269,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): def call(self, inputs, count_weights=None): if isinstance(inputs, (list, np.ndarray)): - inputs = ops.convert_to_tensor_v2(inputs) + inputs = ops.convert_to_tensor_v2_with_dispatch(inputs) if inputs.shape.rank == 1: inputs = array_ops.expand_dims(inputs, 1) diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py index a6de075535c..ea8d6f0fd95 100644 --- a/tensorflow/python/keras/layers/preprocessing/hashing.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing.py @@ -154,7 +154,7 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer): def _preprocess_single_input(self, inp): if isinstance(inp, (list, tuple, np.ndarray)): - inp = ops.convert_to_tensor(inp) + inp = ops.convert_to_tensor_v2_with_dispatch(inp) return inp def _preprocess_inputs(self, inputs): diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py index 87a18db31f3..6e0098cbbe5 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py @@ -681,7 +681,7 @@ def transform(images, if output_shape_value is not None: output_shape = output_shape_value - output_shape = ops.convert_to_tensor_v2( + output_shape = ops.convert_to_tensor_v2_with_dispatch( output_shape, dtypes.int32, name='output_shape') if not output_shape.get_shape().is_compatible_with([2]): diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index 4b75def0247..b8cf233d780 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -145,7 +145,7 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): super(Normalization, self).build(input_shape) def call(self, inputs): - inputs = ops.convert_to_tensor_v2(inputs) + inputs = ops.convert_to_tensor_v2_with_dispatch(inputs) if inputs.shape.rank == 1: inputs = array_ops.expand_dims(inputs, 1) # If the inputs are not floats, cast them to floats. This avoids issues diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py index 3329f32b4fe..c72b8252480 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py @@ -62,8 +62,10 @@ class TableHandler(object): raise RuntimeError("Size mismatch between values and key arrays. " "Keys had size %s, values had size %s." % (len(keys), len(values))) - keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access - values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access + keys = ops.convert_to_tensor_v2_with_dispatch( + keys, dtype=self.table._key_dtype) # pylint: disable=protected-access + values = ops.convert_to_tensor_v2_with_dispatch( + values, dtype=self.table._value_dtype) # pylint: disable=protected-access if values.shape.ndims != 1: raise ValueError("`values` must be 1-dimensional, got an input with " " %s dimensions." % values.shape.ndims) diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py index 2cc8bc2b340..36e326bdc5c 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py @@ -367,7 +367,7 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): # on an implicit call to `build` in the base layer's `adapt`, since # preprocessing changes the input shape. if isinstance(data, (list, tuple, np.ndarray)): - data = ops.convert_to_tensor(data) + data = ops.convert_to_tensor_v2_with_dispatch(data) if isinstance(data, ops.Tensor): if data.shape.rank == 1: @@ -566,7 +566,7 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): def call(self, inputs): if isinstance(inputs, (list, tuple, np.ndarray)): - inputs = ops.convert_to_tensor(inputs) + inputs = ops.convert_to_tensor_v2_with_dispatch(inputs) self._called = True inputs = self._preprocess(inputs) diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 9794189cf09..c7a9a87699c 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -387,9 +387,9 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): else: logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) - # TODO(b/162616551): Remove all compat statements after 08/20/2020. + # TODO(b/162616551): Remove all compat statements after 9/2/2020. # This follows b/161915509 and is mainly to test the stateless Case op. - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): # The first two attributes are added to support TFLite use case. supportive_attributes = { 'time_major': time_major, @@ -483,7 +483,7 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): if dropout_mask is not None: inputs = inputs * dropout_mask[0] - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): gru_kwargs = { 'inputs': inputs, 'init_h': _read_variable_value(initial_state[0]), @@ -797,7 +797,7 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias, true_fn=cudnn_gru_fn, false_fn=standard_gru_fn) - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): # Chooses the implementation dynamicly based on the running device. (last_output, outputs, new_h, runtime) = control_flow_ops.execute_fn_for_device( @@ -1141,7 +1141,7 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): else: logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): # The first two attributes are added to support TFLite use case. supportive_attributes = { 'time_major': time_major, @@ -1202,7 +1202,7 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) if dropout_mask is not None: inputs = inputs * dropout_mask[0] - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): lstm_kwargs = { 'inputs': inputs, @@ -1633,7 +1633,7 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel, true_fn=cudnn_lstm_fn, false_fn=stardard_lstm_fn) - if compat.forward_compatible(2020, 8, 27): + if compat.forward_compatible(2020, 9, 2): # Chooses the implementation dynamicly based on the running device. (last_output, outputs, new_h, new_c, runtime) = control_flow_ops.execute_fn_for_device( diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py index b0fd5189b17..19ea3dcce90 100644 --- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py +++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py @@ -40,8 +40,10 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): def testResidualWrapper(self): wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper - x = ops.convert_to_tensor_v2(np.array([[1., 1., 1.]]), dtype="float32") - m = ops.convert_to_tensor_v2(np.array([[0.1, 0.1, 0.1]]), dtype="float32") + x = ops.convert_to_tensor_v2_with_dispatch( + np.array([[1., 1., 1.]]), dtype="float32") + m = ops.convert_to_tensor_v2_with_dispatch( + np.array([[0.1, 0.1, 0.1]]), dtype="float32") base_cell = rnn_cell_impl.GRUCell( 3, kernel_initializer=init_ops.constant_initializer(0.5), bias_initializer=init_ops.constant_initializer(0.5)) @@ -62,9 +64,10 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): def testResidualWrapperWithSlice(self): wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper - x = ops.convert_to_tensor_v2( + x = ops.convert_to_tensor_v2_with_dispatch( np.array([[1., 1., 1., 1., 1.]]), dtype="float32") - m = ops.convert_to_tensor_v2(np.array([[0.1, 0.1, 0.1]]), dtype="float32") + m = ops.convert_to_tensor_v2_with_dispatch( + np.array([[0.1, 0.1, 0.1]]), dtype="float32") base_cell = rnn_cell_impl.GRUCell( 3, kernel_initializer=init_ops.constant_initializer(0.5), bias_initializer=init_ops.constant_initializer(0.5)) @@ -116,7 +119,8 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): base_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell") rnn_cell = wrapper(base_cell) rnn_layer = layers.RNN(rnn_cell) - inputs = ops.convert_to_tensor_v2([[[1]]], dtype=dtypes.float32) + inputs = ops.convert_to_tensor_v2_with_dispatch([[[1]]], + dtype=dtypes.float32) rnn_layer(inputs) wrapper_name = generic_utils.to_snake_case(wrapper.__name__) @@ -140,8 +144,8 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): base_cell = rnn_cell_impl.MultiRNNCell( [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) rnn_cell = wrapper(base_cell) - inputs = ops.convert_to_tensor_v2([[1]], dtype=dtypes.float32) - state = ops.convert_to_tensor_v2([[1]], dtype=dtypes.float32) + inputs = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32) + state = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32) _ = rnn_cell(inputs, [state, state]) weights = base_cell._cells[0].weights self.assertLen(weights, expected_len=2) diff --git a/tensorflow/python/keras/layers/subclassed_layers_test.py b/tensorflow/python/keras/layers/subclassed_layers_test.py index 6adeb0934ed..572ce859702 100644 --- a/tensorflow/python/keras/layers/subclassed_layers_test.py +++ b/tensorflow/python/keras/layers/subclassed_layers_test.py @@ -37,7 +37,7 @@ class SubclassedLayersTest(keras_parameterized.TestCase): class BuildConstantLayer(keras.layers.Layer): def build(self, input_shape): - self.b = ops.convert_to_tensor_v2(2.0) + self.b = ops.convert_to_tensor_v2_with_dispatch(2.0) def call(self, inputs): return self.b * inputs @@ -46,7 +46,7 @@ class SubclassedLayersTest(keras_parameterized.TestCase): model = testing_utils.get_model_from_layers( [layer, keras.layers.Dense(1)], input_shape=(1,)) - x = ops.convert_to_tensor_v2([[3.0]]) + x = ops.convert_to_tensor_v2_with_dispatch([[3.0]]) self.assertEqual( tf_utils.is_symbolic_tensor(model(x)), not context.executing_eagerly()) self.assertEqual( @@ -58,10 +58,10 @@ class SubclassedLayersTest(keras_parameterized.TestCase): class BuildDerivedConstantLayer(keras.layers.Layer): def build(self, input_shape): - a = ops.convert_to_tensor_v2(1.0) + a = ops.convert_to_tensor_v2_with_dispatch(1.0) b = 2.0 * a self.variable = variables.Variable(b) - self.constant = ops.convert_to_tensor_v2(self.variable) + self.constant = ops.convert_to_tensor_v2_with_dispatch(self.variable) def call(self, inputs): return self.variable * self.constant * inputs @@ -70,7 +70,7 @@ class SubclassedLayersTest(keras_parameterized.TestCase): model = testing_utils.get_model_from_layers( [layer, keras.layers.Dense(1)], input_shape=(1,)) - x = ops.convert_to_tensor_v2([[3.0]]) + x = ops.convert_to_tensor_v2_with_dispatch([[3.0]]) self.assertEqual( tf_utils.is_symbolic_tensor(model(x)), not context.executing_eagerly()) self.assertEqual( diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index e128323a1a6..02932337ed1 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -637,7 +637,7 @@ class AutoLambdaTest(keras_parameterized.TestCase): self.assertAllEqual(model(ones), 3.0 * ones) def test_numerical_correctness_simple(self): - x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]]) + x = ops.convert_to_tensor_v2_with_dispatch([[-1., 0., -2., 1.]]) inputs = keras.Input(shape=(4,)) outputs = gen_nn_ops.relu(inputs) model = keras.Model(inputs, outputs) @@ -645,7 +645,7 @@ class AutoLambdaTest(keras_parameterized.TestCase): self.assertAllClose(y, [[0., 0., 0., 1.]]) def test_numerical_correctness_with_attrs(self): - x = ops.convert_to_tensor_v2([[1.5, 1.5], [2.5, 3.5]]) + x = ops.convert_to_tensor_v2_with_dispatch([[1.5, 1.5], [2.5, 3.5]]) inputs = keras.Input(shape=(2,)) outputs = math_ops.reduce_mean(inputs, axis=1) model = keras.Model(inputs, outputs) @@ -653,7 +653,7 @@ class AutoLambdaTest(keras_parameterized.TestCase): self.assertAllClose(y, [1.5, 3.]) def test_numerical_correctness_serialization(self): - x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]]) + x = ops.convert_to_tensor_v2_with_dispatch([[-1., 0., -2., 1.]]) inputs = keras.Input(shape=(4,)) outputs = gen_nn_ops.relu(inputs) model1 = keras.Model(inputs, outputs) diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 671fe65d520..4f0eee81f1f 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -44,7 +44,6 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test from tensorflow.python.training.tracking import util as trackable_util from tensorflow.python.util import nest -from tensorflow.python.util import object_identity class _RNNCellWithConstants(keras.layers.Layer): @@ -130,10 +129,11 @@ class TimeDistributedTest(keras_parameterized.TestCase): # check whether the model variables are present in the # trackable list of objects - checkpointed_objects = object_identity.ObjectIdentitySet( - trackable_util.list_objects(model)) + checkpointed_object_ids = { + id(o) for o in trackable_util.list_objects(model) + } for v in model.variables: - self.assertIn(v, checkpointed_objects) + self.assertIn(id(v), checkpointed_object_ids) def test_timedistributed_static_batch_size(self): model = keras.models.Sequential() @@ -492,10 +492,11 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): # check whether the model variables are present in the # trackable list of objects - checkpointed_objects = object_identity.ObjectIdentitySet( - trackable_util.list_objects(model)) + checkpointed_object_ids = { + id(o) for o in trackable_util.list_objects(model) + } for v in model.variables: - self.assertIn(v, checkpointed_objects) + self.assertIn(id(v), checkpointed_object_ids) # test compute output shape ref_shape = model.layers[-1].output.shape @@ -1030,10 +1031,11 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): # check whether the model variables are present in the # trackable list of objects - checkpointed_objects = object_identity.ObjectIdentitySet( - trackable_util.list_objects(model)) + checkpointed_object_ids = { + id(o) for o in trackable_util.list_objects(model) + } for v in model.variables: - self.assertIn(v, checkpointed_objects) + self.assertIn(id(v), checkpointed_object_ids) # test compute output shape ref_shape = model.layers[-1].output.shape diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index bda32897fc5..6b74121cf80 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -1189,7 +1189,7 @@ def mean_squared_error(y_true, y_pred): Returns: Mean squared error values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1) @@ -1222,7 +1222,7 @@ def mean_absolute_error(y_true, y_pred): Returns: Mean absolute error values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) return K.mean(math_ops.abs(y_pred - y_true), axis=-1) @@ -1257,7 +1257,7 @@ def mean_absolute_percentage_error(y_true, y_pred): Returns: Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) diff = math_ops.abs( (y_true - y_pred) / K.maximum(math_ops.abs(y_true), K.epsilon())) @@ -1296,7 +1296,7 @@ def mean_squared_logarithmic_error(y_true, y_pred): Returns: Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) first_log = math_ops.log(K.maximum(y_pred, K.epsilon()) + 1.) second_log = math_ops.log(K.maximum(y_true, K.epsilon()) + 1.) @@ -1344,7 +1344,7 @@ def squared_hinge(y_true, y_pred): Returns: Squared hinge loss values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) y_true = _maybe_convert_labels(y_true) return K.mean( @@ -1377,7 +1377,7 @@ def hinge(y_true, y_pred): Returns: Hinge loss values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) y_true = _maybe_convert_labels(y_true) return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1) @@ -1409,7 +1409,7 @@ def categorical_hinge(y_true, y_pred): Returns: Categorical hinge loss values. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) pos = math_ops.reduce_sum(y_true * y_pred, axis=-1) neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1) @@ -1444,7 +1444,7 @@ def huber(y_true, y_pred, delta=1.0): delta = math_ops.cast(delta, dtype=K.floatx()) error = math_ops.subtract(y_pred, y_true) abs_error = math_ops.abs(error) - half = ops.convert_to_tensor_v2(0.5, dtype=abs_error.dtype) + half = ops.convert_to_tensor_v2_with_dispatch(0.5, dtype=abs_error.dtype) return K.mean( array_ops.where_v2( abs_error <= delta, half * math_ops.pow(error, 2), @@ -1481,7 +1481,7 @@ def log_cosh(y_true, y_pred): Returns: Logcosh error values. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) def _logcosh(x): @@ -1518,9 +1518,10 @@ def categorical_crossentropy(y_true, Returns: Categorical crossentropy loss value. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) - label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx()) + label_smoothing = ops.convert_to_tensor_v2_with_dispatch( + label_smoothing, dtype=K.floatx()) def _smooth_labels(): num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype) @@ -1557,7 +1558,7 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): Returns: Sparse categorical crossentropy loss value. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) return K.sparse_categorical_crossentropy( y_true, y_pred, from_logits=from_logits, axis=axis) @@ -1588,9 +1589,10 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): Returns: Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) - label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx()) + label_smoothing = ops.convert_to_tensor_v2_with_dispatch( + label_smoothing, dtype=K.floatx()) def _smooth_labels(): return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing @@ -1638,7 +1640,7 @@ def kl_divergence(y_true, y_pred): Raises: TypeError: If `y_true` cannot be cast to the `y_pred.dtype`. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) y_true = K.clip(y_true, K.epsilon(), 1) y_pred = K.clip(y_pred, K.epsilon(), 1) @@ -1674,7 +1676,7 @@ def poisson(y_true, y_pred): Raises: InvalidArgumentError: If `y_true` and `y_pred` have incompatible shapes. """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) y_true = math_ops.cast(y_true, y_pred.dtype) return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1) diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 34213c8308a..4de49e69829 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -95,16 +95,19 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase): p = backend.placeholder() o = losses.categorical_crossentropy(t, p) - t_val = ops.convert_to_tensor_v2([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) - p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06], - [.05, .01, .94]]) + t_val = ops.convert_to_tensor_v2_with_dispatch([[1., 0., 0.], [0., 1., 0.], + [0., 0., 1.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05], + [.05, .89, .06], + [.05, .01, .94]]) f = backend.function([t, p], o) result = f([t_val, p_val]) self.assertArrayNear(result, [.105, .116, .062], 1e-3) # from logits - p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.], + [2., 3., 5.]]) o = losses.categorical_crossentropy(t, p, from_logits=True) f = backend.function([t, p], o) @@ -133,16 +136,18 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase): p = backend.placeholder() o = losses.sparse_categorical_crossentropy(t, p) - t_val = ops.convert_to_tensor_v2([0, 1, 2]) - p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06], - [.05, .01, .94]]) + t_val = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05], + [.05, .89, .06], + [.05, .01, .94]]) f = backend.function([t, p], o) result = f([t_val, p_val]) self.assertArrayNear(result, [.105, .116, .062], 1e-3) # from logits - p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.], + [2., 3., 5.]]) o = losses.sparse_categorical_crossentropy(t, p, from_logits=True) f = backend.function([t, p], o) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index b3f391c7897..eea1881ba4f 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -963,7 +963,7 @@ class _ConfusionMatrixConditionCount(Metric): result = self.accumulator[0] else: result = self.accumulator - return ops.convert_to_tensor_v2(result) + return ops.convert_to_tensor_v2_with_dispatch(result) def reset_states(self): num_thresholds = len(to_list(self.thresholds)) @@ -3239,7 +3239,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): Returns: Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` """ - y_pred = ops.convert_to_tensor_v2(y_pred) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) threshold = math_ops.cast(threshold, y_pred.dtype) y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) return K.mean(math_ops.equal(y_true, y_pred), axis=-1) @@ -3297,8 +3297,8 @@ def sparse_categorical_accuracy(y_true, y_pred): Returns: Sparse categorical accuracy values. """ - y_pred = ops.convert_to_tensor_v2(y_pred) - y_true = ops.convert_to_tensor_v2(y_true) + y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) + y_true = ops.convert_to_tensor_v2_with_dispatch(y_true) y_pred_rank = y_pred.shape.ndims y_true_rank = y_true.shape.ndims # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) @@ -3364,8 +3364,8 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): Returns: Sparse top K categorical accuracy value. """ - y_pred_rank = ops.convert_to_tensor_v2(y_pred).shape.ndims - y_true_rank = ops.convert_to_tensor_v2(y_true).shape.ndims + y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims + y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) if (y_true_rank is not None) and (y_pred_rank is not None): if y_pred_rank > 2: diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 7b339fc5a47..554609d4086 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -71,7 +71,7 @@ class KerasSumTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(m.total), 100) # check update_state() and result() + state accumulation + tensor input - update_op = m.update_state(ops.convert_to_tensor_v2([1, 5])) + update_op = m.update_state(ops.convert_to_tensor_v2_with_dispatch([1, 5])) self.evaluate(update_op) self.assertAlmostEqual(self.evaluate(m.result()), 106) self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py index 20770061639..d0d6442639f 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py @@ -126,7 +126,7 @@ class AutoCastVariable(variables.Variable, core.Tensor): raise ValueError( 'Incompatible type conversion requested to type {!r} for variable ' 'of type {!r}'.format(dtype.name, self.dtype.name)) - val = ops.convert_to_tensor_v2( + val = ops.convert_to_tensor_v2_with_dispatch( self._variable, dtype=self._variable.dtype, name=name) return math_ops.cast(val, self.dtype) diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py index 9a9d174a64f..75fe3d92565 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py @@ -124,10 +124,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): def testGetScaledLoss(self): opt = gradient_descent.SGD(2.0) opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.) - loss = ops.convert_to_tensor_v2(5.) + loss = ops.convert_to_tensor_v2_with_dispatch(5.) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss))) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)())) - loss = ops.convert_to_tensor_v2(5., dtype='float16') + loss = ops.convert_to_tensor_v2_with_dispatch(5., dtype='float16') self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss))) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)())) @@ -135,8 +135,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): opt = gradient_descent.SGD(2.0) opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) scaled_grads = [ - ops.convert_to_tensor_v2(3.), None, - ops.convert_to_tensor_v2(-4., dtype='float16') + ops.convert_to_tensor_v2_with_dispatch(3.), None, + ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16') ] grads = opt.get_unscaled_gradients(scaled_grads) grads = [self.evaluate(g) if g is not None else g for g in grads] @@ -146,9 +146,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): opt = gradient_descent.SGD(2.0) opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) sparse_scaled_grad = ops.IndexedSlices( - ops.convert_to_tensor_v2([[4., 2.], [8., 5.]]), - ops.convert_to_tensor_v2([1, 3], dtype='int32'), - dense_shape=ops.convert_to_tensor_v2([5, 2], dtype='int32')) + ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]), + ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'), + dense_shape=ops.convert_to_tensor_v2_with_dispatch([5, 2], + dtype='int32')) sparse_grad = opt.get_unscaled_gradients([sparse_scaled_grad])[0] self.assertIsInstance(sparse_grad, ops.IndexedSlices) self.assertAllEqual([[2., 1.], [4., 2.5]], diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index 592057f0b56..c8acd86c5f7 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -543,7 +543,12 @@ def set_policy(policy): passed to the layer constructor. If no global policy is set, layers will instead default to a Policy constructed from `tf.keras.backend.floatx()`. - See `keras.mixed_precision.experimental.Policy` for more information. + Only floating point policies can be set as the global policy, such as + `'float32'` and `'mixed_float16'`. Non-floating point policies such as + `'int32'` and `'complex64'` cannot be set as the global policy because most + layers do not support such policies. + + See `tf.keras.mixed_precision.experimental.Policy` for more information. Args: policy: A Policy, or a string that will be converted to a Policy.. @@ -559,6 +564,12 @@ def set_policy(policy): is_mixed_policy = policy is not None and policy.should_cast_variables if is_mixed_policy: _check_if_mixed_precision_graph_rewrite_is_enabled(policy) + if (policy is not None and policy.compute_dtype is not None and + not dtypes.as_dtype(policy.compute_dtype).is_floating): + raise ValueError('set_policy can only be used to set the global policy to ' + 'floating-point policies, such as "float32" and ' + '"mixed_float16", but got policy: %s' + % (policy.name,)) _global_policy = policy mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index 94880a9b239..060f80f255b 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -155,6 +155,21 @@ class PolicyTest(test.TestCase, parameterized.TestCase): finally: mp_policy.set_policy(None) + @testing_utils.enable_v2_dtype_behavior + def test_global_policy_dtype_error(self): + with self.assertRaisesRegex( + ValueError, + 'set_policy can only be used to set the global policy to ' + 'floating-point policies, such as "float32" and "mixed_float16", but ' + 'got policy: int32'): + mp_policy.set_policy('int32') + with self.assertRaisesRegex( + ValueError, + 'set_policy can only be used to set the global policy to ' + 'floating-point policies, such as "float32" and "mixed_float16", but ' + 'got policy: complex64'): + mp_policy.set_policy(mp_policy.Policy('complex64')) + @testing_utils.enable_v2_dtype_behavior def test_loss_scale_warning(self): with test.mock.patch.object(tf_logging, 'warn') as mock_warn: diff --git a/tensorflow/python/keras/mixed_precision/experimental/test_util.py b/tensorflow/python/keras/mixed_precision/experimental/test_util.py index 937b378115d..c0d9cbf98d6 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/test_util.py +++ b/tensorflow/python/keras/mixed_precision/experimental/test_util.py @@ -55,7 +55,7 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None): if expected_dtype: assert dx.dtype == expected_dtype, ( 'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype)) - expected_tensor = ops.convert_to_tensor_v2( + expected_tensor = ops.convert_to_tensor_v2_with_dispatch( expected_gradient, dtype=dx.dtype, name='expected_gradient') # Control dependency is to ensure input is available. It's possible the # dataset will throw a StopIteration to indicate there is no more data, in diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py index 8c895ae07f4..404b3f81e3f 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta.py @@ -101,7 +101,8 @@ class Adadelta(optimizer_v2.OptimizerV2): super(Adadelta, self)._prepare_local(var_device, var_dtype, apply_state) apply_state[(var_device, var_dtype)].update( dict( - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), rho=array_ops.identity(self._get_hyper('rho', var_dtype)))) def set_weights(self, weights): diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py index ba76b837942..4d3294ab9f8 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad.py @@ -87,7 +87,8 @@ class Adagrad(optimizer_v2.OptimizerV2): super(Adagrad, self)._prepare_local(var_device, var_dtype, apply_state) apply_state[(var_device, var_dtype)].update( dict( - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), neg_lr_t=-apply_state[(var_device, var_dtype)]['lr_t'], zero=array_ops.zeros((), dtype=dtypes.int64))) diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index 1fccd116012..e4896fd167e 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -144,7 +144,8 @@ class Adam(optimizer_v2.OptimizerV2): apply_state[(var_device, var_dtype)].update( dict( lr=lr, - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), beta_1_t=beta_1_t, beta_1_power=beta_1_power, one_minus_beta_1_t=1 - beta_1_t, @@ -396,7 +397,8 @@ class NonFusedAdam(optimizer_v2.OptimizerV2): apply_state[(var_device, var_dtype)].update( dict( lr=lr, - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), beta_1_t=beta_1_t, beta_1_power=beta_1_power, one_minus_beta_1_t=1 - beta_1_t, diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py index 3f4312c731e..26cc59b1f98 100644 --- a/tensorflow/python/keras/optimizer_v2/adamax.py +++ b/tensorflow/python/keras/optimizer_v2/adamax.py @@ -122,7 +122,8 @@ class Adamax(optimizer_v2.OptimizerV2): apply_state[(var_device, var_dtype)].update( dict( neg_scaled_lr=-lr_t / (1 - beta_1_power), - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), beta_1_t=beta_1_t, beta_1_power=beta_1_power, one_minus_beta_1_t=1 - beta_1_t, diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py index 4dcff3d6c44..30b4f2145bb 100644 --- a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py +++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py @@ -143,7 +143,7 @@ class ExponentialDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "ExponentialDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype decay_steps = math_ops.cast(self.decay_steps, dtype) @@ -237,11 +237,11 @@ class PiecewiseConstantDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "PiecewiseConstant"): - boundaries = nest.map_structure(ops.convert_to_tensor_v2, + boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, nest.flatten(self.boundaries)) - values = nest.map_structure(ops.convert_to_tensor_v2, + values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, nest.flatten(self.values)) - x_recomp = ops.convert_to_tensor_v2(step) + x_recomp = ops.convert_to_tensor_v2_with_dispatch(step) for i, b in enumerate(boundaries): if b.dtype.base_dtype != x_recomp.dtype.base_dtype: # We cast the boundaries to have the same type as the step @@ -374,7 +374,7 @@ class PolynomialDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "PolynomialDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype end_learning_rate = math_ops.cast(self.end_learning_rate, dtype) @@ -494,7 +494,7 @@ class InverseTimeDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "InverseTimeDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype decay_steps = math_ops.cast(self.decay_steps, dtype) @@ -588,7 +588,7 @@ class CosineDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "CosineDecay"): - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype decay_steps = math_ops.cast(self.decay_steps, dtype) @@ -687,7 +687,7 @@ class CosineDecayRestarts(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "SGDRDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype first_decay_steps = math_ops.cast(self.first_decay_steps, dtype) @@ -824,7 +824,7 @@ class LinearCosineDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "LinearCosineDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype decay_steps = math_ops.cast(self.decay_steps, dtype) @@ -950,7 +950,7 @@ class NoisyLinearCosineDecay(LearningRateSchedule): def __call__(self, step): with ops.name_scope_v2(self.name or "NoisyLinearCosineDecay") as name: - initial_learning_rate = ops.convert_to_tensor_v2( + initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch( self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype decay_steps = math_ops.cast(self.decay_steps, dtype) diff --git a/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py index ad280568fc7..ab8e4f55b52 100644 --- a/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py +++ b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py @@ -148,10 +148,11 @@ def piecewise_constant(x, boundaries, values, name=None): the learning rate value across different invocations of optimizer functions. @end_compatibility """ - boundaries = nest.map_structure(ops.convert_to_tensor_v2, + boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, nest.flatten(boundaries)) - values = nest.map_structure(ops.convert_to_tensor_v2, nest.flatten(values)) - x_recomp = ops.convert_to_tensor(x) + values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, + nest.flatten(values)) + x_recomp = ops.convert_to_tensor_v2_with_dispatch(x) # Avoid explicit conversion to x's dtype. This could result in faulty # comparisons, for example if floats are converted to integers. for i, b in enumerate(boundaries): diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py index 090eabacf1e..550db0f6472 100644 --- a/tensorflow/python/keras/optimizer_v2/nadam.py +++ b/tensorflow/python/keras/optimizer_v2/nadam.py @@ -122,7 +122,7 @@ class Nadam(optimizer_v2.OptimizerV2): apply_state[(var_device, var_dtype)] = dict( lr_t=lr_t, neg_lr_t=-lr_t, - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch(self.epsilon, var_dtype), beta_1_t=beta_1_t, beta_2_t=beta_2_t, m_t=m_t, diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index 6a5b9865372..65539fa11aa 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -237,7 +237,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def testComputeGradientsWithTensors(self): with testing_utils.use_gpu(): - x = ops.convert_to_tensor_v2(1.0) + x = ops.convert_to_tensor_v2_with_dispatch(1.0) def f(): return x * x diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py index 1fa2577e72f..407dbf33206 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py @@ -167,7 +167,8 @@ class RMSprop(optimizer_v2.OptimizerV2): apply_state[(var_device, var_dtype)].update( dict( neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"], - epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), + epsilon=ops.convert_to_tensor_v2_with_dispatch( + self.epsilon, var_dtype), rho=rho, momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)), one_minus_rho=1. - rho)) diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py index f943967d65b..40ec3d1a030 100644 --- a/tensorflow/python/keras/preprocessing/image.py +++ b/tensorflow/python/keras/preprocessing/image.py @@ -111,7 +111,7 @@ def smart_resize(x, size, interpolation='bilinear'): if len(size) != 2: raise ValueError('Expected `size` to be a tuple of 2 integers, ' 'but got: %s' % (size,)) - img = ops.convert_to_tensor(x) + img = ops.convert_to_tensor_v2_with_dispatch(x) if img.shape.rank is not None: if img.shape.rank != 3: raise ValueError( diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 4889ee97211..556675d4bb5 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -43,7 +43,6 @@ from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking.tracking import delete_tracking from tensorflow.python.util import compat from tensorflow.python.util import nest -from tensorflow.python.util import object_identity # To avoid circular dependencies between keras/engine and keras/saving, # code in keras/saving must delay imports. @@ -179,8 +178,6 @@ class KerasObjectLoader(tf_load.Loader): # records all nodes that were generated directly/indirectly from the config, # so that they do not get recreated multiple times. self._nodes_recreated_from_config = {} - self._all_nodes_recreated_from_config = ( - object_identity.ObjectIdentityWeakSet()) # Store all node ids that have already been traversed when tracking nodes # that were recreated from the config. self._traversed_nodes_from_config = [] @@ -293,7 +290,6 @@ class KerasObjectLoader(tf_load.Loader): 'Object: {}'.format(obj_child)) self._nodes_recreated_from_config[child_id] = ( obj_child, self._config_node_setter(setter)) - self._all_nodes_recreated_from_config.add(obj_child) self._add_children_recreated_from_config( obj_child, child_proto, child_id) @@ -363,7 +359,6 @@ class KerasObjectLoader(tf_load.Loader): setter = self._config_node_setter(_revive_setter) self._nodes_recreated_from_config[node_id] = obj, setter - self._all_nodes_recreated_from_config.add(obj) self._add_children_recreated_from_config( obj, self._proto.nodes[node_id], node_id) return obj, setter diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 1dff9a2e8cf..69115e04cb0 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -507,7 +507,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.assertAllClose( model.predict(input_arr), - loaded.signatures['predict'](ops.convert_to_tensor_v2( + loaded.signatures['predict'](ops.convert_to_tensor_v2_with_dispatch( input_arr.astype('float32')))['predictions']) feature = { @@ -517,7 +517,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): example = example_pb2.Example( features=feature_pb2.Features(feature=feature)) outputs = loaded.signatures['parse_and_predict']( - ops.convert_to_tensor_v2([example.SerializeToString()])) + ops.convert_to_tensor_v2_with_dispatch([example.SerializeToString()])) self.assertAllClose(model.predict(input_arr), outputs['predictions']) self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs']) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 550ff664823..96868d009f4 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.python import tf2 from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -937,3 +938,69 @@ def use_gpu(): """Uses gpu when requested and available.""" with device(should_use_gpu=True): yield + + +def for_all_test_methods(decorator, *args, **kwargs): + """Generate class-level decorator from given method-level decorator. + + It is expected for the given decorator to take some arguments and return + a method that is then called on the test method to produce a decorated + method. + + Args: + decorator: The decorator to apply. + *args: Positional arguments + **kwargs: Keyword arguments + Returns: Function that will decorate a given classes test methods with the + decorator. + """ + + def all_test_methods_impl(cls): + """Apply decorator to all test methods in class.""" + for name in dir(cls): + value = getattr(cls, name) + if callable(value) and name.startswith('test') and (name != + 'test_session'): + setattr(cls, name, decorator(*args, **kwargs)(value)) + return cls + + return all_test_methods_impl + + +# The description is just for documentation purposes. +def run_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute test with TensorFloat-32 disabled. + + While almost every real-world deep learning model runs fine with + TensorFloat-32, many tests use assertAllClose or similar methods. + TensorFloat-32 matmuls typically will cause such methods to fail with the + default tolerances. + + Args: + description: A description used for documentation purposes, describing why + the test requires TensorFloat-32 to be disabled. + + Returns: + Decorator which runs a test with TensorFloat-32 disabled. + """ + + def decorator(f): + + @functools.wraps(f) + def decorated(self, *args, **kwargs): + allowed = config.tensor_float_32_execution_enabled() + try: + config.enable_tensor_float_32_execution(False) + f(self, *args, **kwargs) + finally: + config.enable_tensor_float_32_execution(allowed) + + return decorated + + return decorator + + +# The description is just for documentation purposes. +def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute all tests in a class with TensorFloat-32 disabled.""" + return for_all_test_methods(run_without_tensor_float_32, description) diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index 4db3327d1f6..9ce20c24284 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -274,7 +274,6 @@ cuda_py_test( name = "op_callbacks_test", srcs = ["op_callbacks_test.py"], python_version = "PY3", - tags = ["no_cuda11"], xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py index 4b9b625cc0b..f47ee627a6f 100644 --- a/tensorflow/python/keras/tests/model_subclassing_test.py +++ b/tensorflow/python/keras/tests/model_subclassing_test.py @@ -428,7 +428,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): def call(self, inputs): return inputs + self.b + self.c - x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32')) + x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32')) model = MyModel() model(x) self.assertEqual(1, len(model.trainable_weights)) @@ -444,7 +444,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): def call(self, inputs): return inputs + self.b + self.c - x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32')) + x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32')) model = MyModelCustomBuild() model(x) self.assertEqual(1, len(model.trainable_weights)) @@ -467,7 +467,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): self.add_update(self.c.assign(inputs[1, :])) return inputs + self.b + self.c - x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32')) + x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32')) model = MyModel() model(x) diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 3195bb0eb13..7959b0263b1 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -29,7 +29,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.conv_utils import convert_kernel from tensorflow.python.util import deprecation from tensorflow.python.util import nest -from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import keras_export @@ -104,7 +103,7 @@ def count_params(weights): Returns: The total number of scalars composing the weights """ - unique_weights = object_identity.ObjectIdentitySet(weights) + unique_weights = {id(w): w for w in weights}.values() weight_shapes = [w.shape.as_list() for w in unique_weights] standardized_weight_shapes = [ [0 if w_i is None else w_i for w_i in w] for w in weight_shapes @@ -502,4 +501,3 @@ def cached_per_instance(f): wrapped.cache = cache return wrapped - diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py index b8a063e3b42..08ef613c3e2 100644 --- a/tensorflow/python/keras/utils/losses_utils.py +++ b/tensorflow/python/keras/utils/losses_utils.py @@ -253,11 +253,11 @@ def compute_weighted_loss(losses, ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access if not isinstance(losses, keras_tensor.KerasTensor): - losses = ops.convert_to_tensor_v2(losses) + losses = ops.convert_to_tensor_v2_with_dispatch(losses) input_dtype = losses.dtype if not isinstance(sample_weight, keras_tensor.KerasTensor): - sample_weight = ops.convert_to_tensor_v2(sample_weight) + sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight) # TODO(psv): Handle casting here in a better way, eg. if losses is float64 # we do not want to lose precision. diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py index 7d47850e8aa..5b3905a28da 100644 --- a/tensorflow/python/keras/utils/metrics_utils.py +++ b/tensorflow/python/keras/utils/metrics_utils.py @@ -311,7 +311,8 @@ def update_confusion_matrix_variables(variables_to_update, y_true = math_ops.cast(y_true, dtype=variable_dtype) y_pred = math_ops.cast(y_pred, dtype=variable_dtype) - thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype) + thresholds = ops.convert_to_tensor_v2_with_dispatch( + thresholds, dtype=variable_dtype) num_thresholds = thresholds.shape[0] if multi_label: one_thresh = math_ops.equal( diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py index 9a3939e0c39..73d8671e388 100644 --- a/tensorflow/python/keras/utils/tf_utils_test.py +++ b/tensorflow/python/keras/utils/tf_utils_test.py @@ -44,14 +44,17 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): self.assertFalse(tf_utils.is_symbolic_tensor( variables.Variable(name='blah', initial_value=0.))) self.assertFalse( - tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.))) + tf_utils.is_symbolic_tensor( + ops.convert_to_tensor_v2_with_dispatch(0.))) self.assertFalse(tf_utils.is_symbolic_tensor( sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))) else: self.assertTrue(tf_utils.is_symbolic_tensor( variables.Variable(name='blah', initial_value=0.))) - self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.))) + self.assertTrue( + tf_utils.is_symbolic_tensor( + ops.convert_to_tensor_v2_with_dispatch(0.))) self.assertTrue(tf_utils.is_symbolic_tensor( sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))) @@ -61,7 +64,7 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): class CustomClass(object): def value(self): - return ops.convert_to_tensor_v2(42.) + return ops.convert_to_tensor_v2_with_dispatch(42.) ops.register_tensor_conversion_function( CustomClass, lambda value, **_: value.value()) @@ -72,7 +75,8 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): self.assertFalse(tf_utils.is_symbolic_tensor( variables.Variable(name='blah', initial_value=0.))) self.assertFalse( - tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.))) + tf_utils.is_symbolic_tensor( + ops.convert_to_tensor_v2_with_dispatch(0.))) self.assertFalse(tf_utils.is_symbolic_tensor( sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))) @@ -80,7 +84,9 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): else: self.assertTrue(tf_utils.is_symbolic_tensor( variables.Variable(name='blah', initial_value=0.))) - self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.))) + self.assertTrue( + tf_utils.is_symbolic_tensor( + ops.convert_to_tensor_v2_with_dispatch(0.))) self.assertTrue(tf_utils.is_symbolic_tensor( sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))) @@ -95,7 +101,7 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): def __init__(self, input_): self._input = input_ - self.value = ops.convert_to_tensor_v2([[42.]]) + self.value = ops.convert_to_tensor_v2_with_dispatch([[42.]]) @property def dtype(self): @@ -110,7 +116,7 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): def __init__(self, fn, **kwargs): def _fn(*fargs, **fkwargs): d = fn(*fargs, **fkwargs) - x = ops.convert_to_tensor_v2(d) + x = ops.convert_to_tensor_v2_with_dispatch(d) d.shape = x.shape d.get_shape = x.get_shape return d, x @@ -138,7 +144,7 @@ class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase): model = keras.Model(model.inputs, model(model.outputs)) # Now we instantiate the model and verify we have a `Foo` object, not a # `Tensor`. - y = model(ops.convert_to_tensor_v2([[7.]])) + y = model(ops.convert_to_tensor_v2_with_dispatch([[7.]])) self.assertIsInstance(y, Foo) # Confirm that (custom) loss sees `Foo` instance, not Tensor. obtained_prediction_box = [None] diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 0d6b6ac36a3..2df4860d266 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2048,6 +2048,9 @@ cuda_py_test( name = "dynamic_partition_op_test", size = "medium", srcs = ["dynamic_partition_op_test.py"], + tags = [ + "multi_and_single_gpu", + ], tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index 30b61027813..ac82a320bb6 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -130,6 +130,7 @@ class BatchMatmulOpTest(test.TestCase): def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape): + @test_util.run_without_tensor_float_32("Tests batch matmul") def Test(self): np.random.seed(42) self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape) @@ -141,6 +142,7 @@ def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape): def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b, use_static_shape): + @test_util.run_without_tensor_float_32("Tests batch matmul") def Test(self): np.random.seed(42) self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape) diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index a9afca8bfe7..0697f7def1b 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -106,7 +106,7 @@ class CholeskyOpTest(test.TestCase): def _verifyCholesky(self, x): # Verify that LL^T == x. chol = linalg_ops.cholesky(x) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(x, chol, verification) @test_util.run_in_graph_and_eager_modes(use_gpu=True) @@ -271,8 +271,8 @@ class CholeskyGradTest(test.TestCase): def Compute(x): # Turn the random matrix x into a Hermitian matrix by # computing the quadratic form x * x^H. - a = math_ops.matmul(x, math_ops.conj( - array_ops.matrix_transpose(x))) / shape[0] + a = test_util.matmul_without_tf32( + x, math_ops.conj(array_ops.matrix_transpose(x))) / shape[0] if batch: a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1]) # Finally take the cholesky decomposition of the Hermitian matrix. diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index 9bd962e75f3..3acc1fe03be 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -48,6 +48,10 @@ def GetTestConfigs(): return test_configs +@test_util.run_all_without_tensor_float_32( + "Tests Conv3d, which in some cases is implemented with a matmul. With " + "tf32, tests fail in some of those cases (and as of August 13 2020, only " + "those cases)") class Conv3DTest(test.TestCase): def _DtypesToTest(self, use_gpu): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index f480f4319da..f7234398cdc 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -310,14 +310,14 @@ class Conv2DTest(test.TestCase): data_format, use_gpu) expected_results.append(expected) computed_results.append(computed) - tolerance = 1e-2 if use_gpu else 1e-5 - expected_values = self.evaluate(expected_results) - computed_values = self.evaluate(computed_results) - for e_value, c_value in zip(expected_values, computed_values): - tf_logging.debug("expected = %s", e_value) - tf_logging.debug("actual = %s", c_value) - self.assertAllClose( - e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=rtol) + tolerance = 1e-2 if use_gpu else 1e-5 + expected_values = self.evaluate(expected_results) + computed_values = self.evaluate(computed_results) + for e_value, c_value in zip(expected_values, computed_values): + tf_logging.debug("expected = %s", e_value) + tf_logging.debug("actual = %s", c_value) + self.assertAllClose( + e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=rtol) def _VerifyValues(self, tensor_in_sizes, diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index a4c07daa940..7c8f389f178 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -268,6 +268,8 @@ class DirichletMultinomialTest(test.TestCase): self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) + @test_util.run_without_tensor_float_32( + "Tests DirichletMultinomial.covariance, which calls matmul") def testCovariance(self): # Shape [2] alpha = [1., 2] diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index 0f963824531..a0d8bef327d 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -200,6 +200,8 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) + @test_util.run_without_tensor_float_32( + "Calls Dirichlet.covariance, which calls matmul") def testVariance(self): alpha = [1., 2, 3] denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index 8c448194076..0fd9790c794 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -23,8 +23,10 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -346,6 +348,19 @@ class DynamicPartitionTest(test.TestCase): res = self.evaluate(partitioned) self.assertEqual(res[-1].shape[0], 192) + # see https://github.com/tensorflow/tensorflow/issues/42500 + def testMultiGPU(self): + device_list = config.list_logical_devices("GPU") + results = [] + for device in device_list: + with ops.device(device.name): + data = constant_op.constant(np.zeros((1000,))) + partitions = constant_op.constant(np.arange(1000, dtype=np.int32) % 10) + result = data_flow_ops.dynamic_partition(data, partitions, 10) + results.append(self.evaluate(result)) + if device_list: + self.assertAllEqual(results, np.zeros((len(device_list), 10, 100))) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py index aa9e356bea5..4236eb93278 100644 --- a/tensorflow/python/kernel_tests/einsum_op_test.py +++ b/tensorflow/python/kernel_tests/einsum_op_test.py @@ -35,6 +35,8 @@ from tensorflow.python.platform import benchmark from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + 'Tests einsum, which sometimes does a matmul with cuBLAS') class EinsumOpTest(test.TestCase): def _check(self, s, *input_shapes, **kwargs): @@ -285,6 +287,8 @@ class EinsumOpTest(test.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + "Tests einsum's gradient, which sometimes does a matmul with cuBLAS") class EinsumGradTest(test.TestCase): def _check_gradient(self, s, *input_shapes): diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index e3268fad2d8..f2348c6c7ac 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -945,6 +945,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): self.assertAllClose(abs_value, count, rtol=tol, atol=tol) +@test_util.run_all_without_tensor_float_32( + "Tests convolutional_orthogonal_1d, which calls matmul") class ConvolutionOrthogonal1dInitializerTest(test.TestCase): @test_util.run_deprecated_v1 @@ -1174,6 +1176,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) +@test_util.run_all_without_tensor_float_32( + "Tests convolutional_orthogonal_3d, which calls matmul") class ConvolutionOrthogonal3dInitializerTest(test.TestCase): @test_util.run_deprecated_v1 diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index ac82f190db0..f42600bd334 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -534,7 +534,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): c_value = self.evaluate(c) expected_c_value = self.evaluate( - math_ops.conj(math_ops.matmul(a_dense, b))) + math_ops.conj(test_util.matmul_without_tf32(a_dense, b))) self.assertAllClose(expected_c_value, c_value) @test_util.run_in_graph_and_eager_modes @@ -576,7 +576,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): transpose_b=transpose_b, adjoint_a=adjoint_a, adjoint_b=adjoint_b) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -640,7 +640,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): adjoint_b=adjoint_b) # Example: t(adj(a) . b) = t(b) . conj(a) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( math_ops.conj(b_mats) if adjoint_b else b_mats, math_ops.conj(a_mats) if adjoint_a else a_mats, transpose_a=not (transpose_b or adjoint_b), @@ -670,7 +670,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul( a_sm, b_mats, conjugate_output=True) - c_dense_t = math_ops.conj(math_ops.matmul(a_mats, b_mats)) + c_dense_t = math_ops.conj(test_util.matmul_without_tf32(a_mats, b_mats)) self.assertAllEqual(c_t.shape, c_dense_t.shape) c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t)) @@ -772,7 +772,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): adjoint_b=adjoint_b) c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense( c_sm, dtypes.float32) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -1143,7 +1143,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense( cholesky_sparse_matrices, dtype) # Compute L * Lh where L is the Sparse Cholesky factor. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True)) verification = twist_matrix(verification, ordering_amd) # Assert that input matrix A satisfies A = L * Lh. @@ -1197,7 +1197,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): cholesky_sparse_matrix, dtype) # Compute L * Lh. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True)) verification = twist_matrix(verification, ordering_amd) @@ -1238,7 +1238,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): cholesky_sparse_matrix, dtypes.float32) # Compute L * Lh. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1])) verification = twist_matrix(verification, ordering_amd) verification_values = self.evaluate(verification) diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py index 35c706cb36a..4aa3474ffbb 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py @@ -162,7 +162,7 @@ class SparseMatrixMatmulTest(test.TestCase): 1.j * np.random.randn(*dense_shape_b))).astype(dtype) a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats) b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -202,7 +202,7 @@ class SparseMatrixMatmulTest(test.TestCase): b_mats = (np.random.randn(*dense_shape_b) + 1.j * np.random.randn(*dense_shape_b)).astype(dtype) a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -240,7 +240,7 @@ class SparseMatrixMatmulTest(test.TestCase): b_mats = sparsify((np.random.randn(*dense_shape_b) + 1.j * np.random.randn(*dense_shape_b))).astype(dtype) b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py index f1d885fd231..273aba4d94f 100644 --- a/tensorflow/python/kernel_tests/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg_grad_test.py @@ -63,6 +63,9 @@ def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_): @test_util.enable_control_flow_v2 @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32( + 'Tests `tf.linalg.expm`, which call matmul. Additionally, calls ops ' + 'which do matmul in their gradient, such as MatrixSolve.') def Test(self): def RandomInput(): @@ -102,6 +105,16 @@ def _GetMatrixBinaryFunctorGradientTest(functor_, **kwargs_): @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32( + 'Tests `tf.linalg.lstsq`, which call matmul. Additionally, calls ops ' + 'which do matmul in their gradient, such as MatrixSolveLs.') + # TODO(b/164254522): With tf32, some tests fails with extremely high absolute + # and relative differences when calling assertAllClose. For example, the test + # test_MatrixSolveLsGradient_float32_10_10_1e-06 of class + # MatrixBinaryFunctorGradientTest fails with a max absolute difference of + # 0.883 and a max relative difference of 736892. We should consider disabling + # tf32 within `tf.linalg.lstsq and perhaps other linear algebra functions, + # even if tf32 is allowed globally. def Test(self): def RandomInput(): diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py index fee6aecb3b0..8d522e80a08 100644 --- a/tensorflow/python/kernel_tests/lu_op_test.py +++ b/tensorflow/python/kernel_tests/lu_op_test.py @@ -91,7 +91,7 @@ class LuOpTest(test.TestCase): # Prepare the upper factor. upper = array_ops.matrix_band_part(lu, 0, -1) - verification = math_ops.matmul(lower, upper) + verification = test_util.matmul_without_tf32(lower, upper) # Permute the rows of product of the Cholesky factors. if num_rows > 0: diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 712d7336b94..737ca777804 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -70,6 +70,7 @@ class MatMulTest(test_lib.TestCase): def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_): + @test_util.run_without_tensor_float_32("Tests matmul") def Test(self): np_val = np.matrix(a_np_) * np.matrix(b_np_) diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py index ffe0f595618..9a5a467a5a1 100644 --- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark @@ -41,7 +40,7 @@ class InverseOpTest(test.TestCase): with self.cached_session(use_gpu=True): # Verify that x^{-1} * x == Identity matrix. inv = linalg_ops.matrix_inverse(y, adjoint=adjoint) - tf_ans = math_ops.matmul(inv, y, adjoint_b=adjoint) + tf_ans = test_util.matmul_without_tf32(inv, y, adjoint_b=adjoint) np_ans = np.identity(y.shape[-1]) if x.ndim > 2: tiling = list(y.shape) diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py index 6cf330ed981..98796f256ab 100644 --- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32 class SquareRootOpTest(test.TestCase): def _verifySquareRoot(self, matrix, np_type): @@ -36,7 +37,7 @@ class SquareRootOpTest(test.TestCase): # Verify that matmul(sqrtm(A), sqrtm(A)) = A sqrt = gen_linalg_ops.matrix_square_root(matrix) - square = math_ops.matmul(sqrt, sqrt) + square = test_util.matmul_without_tf32(sqrt, sqrt) self.assertShapeEqual(matrix, square) self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3) diff --git a/tensorflow/python/kernel_tests/parse_single_example_op_test.py b/tensorflow/python/kernel_tests/parse_single_example_op_test.py index ab270bf0d59..498b6a8fd65 100644 --- a/tensorflow/python/kernel_tests/parse_single_example_op_test.py +++ b/tensorflow/python/kernel_tests/parse_single_example_op_test.py @@ -856,6 +856,7 @@ class ParseSingleExampleTest(test.TestCase): expected_err[1]): out = parsing_ops.parse_single_example(**kwargs) sess.run(flatten_values_tensors_or_sparse(out.values())) + return else: # Returns dict w/ Tensors and SparseTensors. out = parsing_ops.parse_single_example(**kwargs) @@ -939,6 +940,20 @@ class ParseSingleExampleTest(test.TestCase): }, expected_output) + def testExampleLongerThanSpec(self): + serialized = example( + features=features({ + "a": bytes_feature([b"a", b"b"]), + })).SerializeToString() + self._test( + { + "serialized": ops.convert_to_tensor(serialized), + "features": { + "a": parsing_ops.FixedLenFeature(1, dtypes.string) + } + }, + expected_err=(errors_impl.OpError, "Can't parse serialized Example")) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index b895fe4ea99..0a618b7f555 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -200,6 +200,8 @@ def _GetQrGradOpTest(dtype_, shape_, full_matrices_): return a @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests Qr gradient, which calls matmul" + ) def Test(self): np.random.seed(42) # Optimal stepsize for central difference is O(epsilon^{1/3}). diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 9a927b86d0b..83fdfc7a33b 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -169,10 +169,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, @test_util.run_in_graph_and_eager_modes def testVariableShape(self): v = resource_variable_ops.ResourceVariable([1., 1.]) + vshape = resource_variable_ops.variable_shape(v.handle) self.assertAllEqual( - tensor_util.constant_value( - resource_variable_ops.variable_shape(v.handle)), + tensor_util.constant_value(vshape), [2]) + if not context.executing_eagerly(): + self.assertEqual("Const", vshape.op.type) @test_util.run_deprecated_v1 def testDifferentAssignGraph(self): diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 01b324f29fb..7fa31d14777 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -3062,6 +3062,8 @@ class RNNCellTest(test.TestCase, parameterized.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + "Uses an LSTMCell, which calls matmul") class DropoutWrapperTest(test.TestCase, parameterized.TestCase): def _testDropoutWrapper(self, diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index 5be7cb4dd3a..40f8b31b7c2 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -38,6 +38,7 @@ def _AddTest(test_class, op_name, testcase_name, fn): setattr(test_class, test_name, fn) +@test_util.run_all_without_tensor_float_32 class SelfAdjointEigTest(test.TestCase): @test_util.run_deprecated_v1 @@ -160,8 +161,8 @@ def _GetSelfAdjointEigTest(dtype_, shape_, compute_v_): tf_e, tf_v = linalg_ops.self_adjoint_eig(constant_op.constant(a)) # Check that V*diag(E)*V^T is close to A. - a_ev = math_ops.matmul( - math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)), + a_ev = test_util.matmul_without_tf32( + test_util.matmul_without_tf32(tf_v, array_ops.matrix_diag(tf_e)), tf_v, adjoint_b=True) self.assertAllClose(self.evaluate(a_ev), a, atol=atol) diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index a031f9bca07..368a7f18f8b 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -165,6 +165,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): return a, b, a_dims, b_dims @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul") def test_tensordot(self): if dynamic_shape_ and context.executing_eagerly(): self.skipTest("Placeholders not support in eager mode") @@ -196,6 +197,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): self.assertAllEqual(tf_ans.shape, np_ans.shape) @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul") def test_tensordot_scalar_axes(self): if dynamic_shape_ and context.executing_eagerly(): self.skipTest("Placeholders not support in eager mode") diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/python/lib/core/safe_ptr.cc index 2194f2499fd..ce852a4f009 100644 --- a/tensorflow/python/lib/core/safe_ptr.cc +++ b/tensorflow/python/lib/core/safe_ptr.cc @@ -17,10 +17,6 @@ limitations under the License. namespace tensorflow { -Safe_PyObjectPtr make_safe(PyObject* object) { - return Safe_PyObjectPtr(object); -} - Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) { return Safe_TF_TensorPtr(tensor); } diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h index 44d14e9bea4..00f47d7bbe6 100644 --- a/tensorflow/python/lib/core/safe_ptr.h +++ b/tensorflow/python/lib/core/safe_ptr.h @@ -16,20 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ #define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ -#include - #include +#include + #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace tensorflow { namespace detail { -struct PyDecrefDeleter { - void operator()(PyObject* p) const { Py_DECREF(p); } -}; - struct TFTensorDeleter { void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } }; @@ -48,11 +45,6 @@ struct TFBufferDeleter { } // namespace detail -// Safe container for an owned PyObject. On destruction, the reference count of -// the contained object will be decremented. -using Safe_PyObjectPtr = std::unique_ptr; -Safe_PyObjectPtr make_safe(PyObject* o); - // Safe containers for an owned TF_Tensor. On destruction, the tensor will be // deleted by TF_DeleteTensor. using Safe_TF_TensorPtr = std::unique_ptr; diff --git a/tensorflow/python/lib/core/safe_pyobject_ptr.cc b/tensorflow/python/lib/core/safe_pyobject_ptr.cc new file mode 100644 index 00000000000..966d3ec5ab5 --- /dev/null +++ b/tensorflow/python/lib/core/safe_pyobject_ptr.cc @@ -0,0 +1,24 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" + +namespace tensorflow { + +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); +} + +} // namespace tensorflow diff --git a/tensorflow/python/lib/core/safe_pyobject_ptr.h b/tensorflow/python/lib/core/safe_pyobject_ptr.h new file mode 100644 index 00000000000..496bfed6c62 --- /dev/null +++ b/tensorflow/python/lib/core/safe_pyobject_ptr.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_ + +#include + +#include + +namespace tensorflow { +namespace detail { + +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +} // namespace detail + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +using Safe_PyObjectPtr = std::unique_ptr; +Safe_PyObjectPtr make_safe(PyObject* o); + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_ diff --git a/tensorflow/python/ops/collective_ops_multi_worker_test.py b/tensorflow/python/ops/collective_ops_multi_worker_test.py new file mode 100644 index 00000000000..4385a20cd20 --- /dev/null +++ b/tensorflow/python/ops/collective_ops_multi_worker_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multi worker Collective Operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import time + +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib +from tensorflow.python.distribute import multi_process_runner +from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.ops import collective_ops + + +def enable_collective_ops(cluster_resolver): + context.context().configure_collective_ops( + collective_leader="/job:worker/replica:0/task:0") + config_proto = copy.deepcopy(context.context().config) + server_def = tensorflow_server_pb2.ServerDef( + cluster=cluster_resolver.cluster_spec().as_cluster_def(), + default_session_config=config_proto, + job_name=cluster_resolver.task_type, + task_index=cluster_resolver.task_id, + protocol=cluster_resolver.rpc_layer or "grpc") + context.context().enable_collective_ops(server_def) + + +class CollectiveOpTest(test.TestCase): + + def testCheckHealth(self): + + def worker_fn(): + enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver()) + # There may be some delays before the server startup. Check health should + # eventually be OK. + while True: + try: + for task in [ + "/job:worker/replica:0/task:0", + "/job:worker/replica:0/task:1", + ]: + context.context().check_collective_ops_peer_health(task) + except errors.UnavailableError: + continue + break + multi_process_runner.barrier().wait() + + cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) + mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) + mpr.start() + mpr.join() + + def testCheckHealthPeerDown(self): + + def worker_fn(): + enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver()) + context.context().check_collective_ops_peer_health( + "/job:worker/replica:0/task:1",) + + cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) + mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) + mpr.start_single_process("worker", 0) + with self.assertRaises(errors.UnavailableError): + mpr.join() + + def testCheckHealthPeerRestart(self): + + def worker_fn(): + cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver() + enable_collective_ops(cluster_resolver) + + collective_ops.all_reduce( + constant_op.constant(1.), + group_size=2, + group_key=100, + instance_key=100, + merge_op="Add", + final_op="Id", + communication_hint="ring") + + if cluster_resolver.task_type == "worker": + # MultiProcessRunner will auto restart worker-0. + os._exit(1) # pylint: disable=protected-access + else: + # chief should eventually gets FailedPreconditionError after worker-0 + # has restarted. + while True: + time.sleep(1) + try: + context.context().check_collective_ops_peer_health( + "/job:worker/replica:0/task:0",) + except errors.UnavailableError: + pass + except errors.FailedPreconditionError: + break + + cluster_spec = multi_worker_test_base.create_cluster_spec( + has_chief=True, num_workers=1) + mpr = multi_process_runner.MultiProcessRunner( + worker_fn, cluster_spec, auto_restart=True) + mpr.start() + mpr.join() + + def testCheckHealthInvalidPeer(self): + + def worker_fn(): + enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver()) + context.context().check_collective_ops_peer_health("localhost:12345",) + + cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) + mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) + mpr.start_single_process("worker", 0) + with self.assertRaises(errors.InvalidArgumentError): + mpr.join() + + +if __name__ == "__main__": + multi_process_runner.test_main() diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 751a8a00758..1c8d8d69b38 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -97,6 +97,8 @@ class RGBToHSVTest(test_util.TensorFlowTestCase): class RGBToYIQTest(test_util.TensorFlowTestCase): + @test_util.run_without_tensor_float_32( + "Calls rgb_to_yiq and yiq_to_rgb, which use matmul") def testBatch(self): # Build an arbitrary RGB image np.random.seed(7) @@ -127,6 +129,8 @@ class RGBToYIQTest(test_util.TensorFlowTestCase): class RGBToYUVTest(test_util.TensorFlowTestCase): + @test_util.run_without_tensor_float_32( + "Calls rgb_to_yuv and yuv_to_rgb, which use matmul") def testBatch(self): # Build an arbitrary RGB image np.random.seed(7) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 9b864be39a2..7f3d9f6e286 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -541,6 +541,8 @@ class DropoutTest(test_lib.TestCase): _ = nn_ops.dropout(x, 0.5) +@test_util.run_all_without_tensor_float_32( + "Tests _compute_sampled_logits and related functions, which call matmul") class ComputeSampledLogitsTest(test_lib.TestCase): def setUp(self): diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index 85b58055d8f..30e724413f4 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -261,6 +261,9 @@ class MathTest(PForTestCase, parameterized.TestCase): self._test_loop_fn(loop_fn, 4) + @test_util.run_without_tensor_float_32( + "Calls matmul in parallel for-loop and compares result to calling matmul " + "in sequential for-loop") def test_matmul(self): for tr_a in (True, False): for tr_b in (True, False): @@ -745,6 +748,9 @@ class LinalgTest(PForTestCase): self._test_loop_fn(loop_fn, 2) + @test_util.run_without_tensor_float_32( + "Calls einsum in parallel for-loop and compares result to calling einsum " + "in sequential for-loop") def test_einsum(self): b = 10 x_series = random_ops.random_uniform([b, 9, 9]) diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 623f5063c7d..ba184b222ca 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -635,6 +635,8 @@ class BesselTest(test.TestCase, parameterized.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + 'Tests einsum, which sometimes does a matmul with cuBLAS') class EinsumTest(test.TestCase): def _check(self, s, *input_shapes, **kwargs): diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 98e93367c8f..5eeaf96448d 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -142,7 +142,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:xplane_builder", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@pybind11", diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index 401201018d9..7ddce914010 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" @@ -50,7 +51,7 @@ tensorflow::Status ValidateHostPortPair(const std::string& host_port) { // Must be host:port, port must be a number, host must not contain a '/', // host also must not be empty. if (parts.size() != 2 || !absl::SimpleAtoi(parts[1], &port) || - parts[0].find("/") != std::string::npos || parts[0].empty()) { + absl::StrContains(parts[0], "/") || parts[0].empty()) { return tensorflow::errors::InvalidArgument( "Could not interpret \"", host_port, "\" as a host-port pair."); } @@ -123,7 +124,8 @@ PYBIND11_MODULE(_pywrap_profiler, m) { .def("export_to_tb", &ProfilerSessionWrapper::ExportToTensorBoard); m.def("start_server", [](int port) { - auto profiler_server = absl::make_unique(); + auto profiler_server = + absl::make_unique(); profiler_server->StartProfilerServer(port); // Intentionally release profiler server. Should transfer ownership to // caller instead. diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc index ee2ad1e254b..aa59305df6e 100644 --- a/tensorflow/python/profiler/internal/python_hooks.cc +++ b/tensorflow/python/profiler/internal/python_hooks.cc @@ -16,13 +16,19 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { namespace py = ::pybind11; +namespace { + template int ProfileFunction(PyObject* obj, PyFrameObject* frame, int what, PyObject* arg) { @@ -40,40 +46,153 @@ void ThreadingSetProfile(const py::object& callback) { setprofile(callback); } +std::string GetEventName(PyCodeObject* py_code) { + string filename(py::reinterpret_borrow(py_code->co_filename)); + string function; + if (py_code->co_name == nullptr) { + function = ""; + } else { + function = py::reinterpret_borrow(py_code->co_name); + } + + return absl::StrCat("$", io::Basename(filename), ":", py_code->co_firstlineno, + " ", function); +} + +string GetEventName(PyCFunctionObject* py_cfunc) { + PyObject* module = py_cfunc->m_module; + string filename; + bool filename_ok; +#if PY_MAJOR_VERSION < 3 + filename_ok = (module != nullptr && PyString_Check(module)); +#else + filename_ok = (module != nullptr && PyUnicode_Check(module)); +#endif + if (filename_ok) { + filename = py::reinterpret_borrow(module); + } else { + filename = ""; + } + + return absl::StrCat("$", filename, " ", py_cfunc->m_ml->ml_name); +} + +void AddEventToXLine(const PythonTraceEntry& event, XLineBuilder* line, + XPlaneBuilder* plane) { + // TODO(jiesun): maybe add full filename as event stats. + auto xevent = line->AddEvent(*plane->GetOrCreateEventMetadata(event.Name())); + xevent.SetTimestampNs(event.start_time_ns); + xevent.SetEndTimestampNs(event.end_time_ns); +} + +} // namespace + +std::string PythonTraceEntry::Name() const { + std::string event_name; + if (code_object) { + return GetEventName(code_object); + } else if (function_object) { + return GetEventName(function_object); + } + return ""; +} + PythonHooks* PythonHooks::GetSingleton() { static PythonHooks* singleton = new PythonHooks; return singleton; } -void PythonHooks::Start(const PythonHooksOptions& option) { +void PythonHooks::Start(const PythonHooksOptions& options) { if (!Py_IsInitialized()) return; - if (option.enable_python_traceme || option.enable_trace_python_function) { + options_ = options; + start_timestamp_ns_ = EnvTime::NowNanos(); + if (options_.enable_python_traceme || options_.enable_trace_python_function) { PyGILState_STATE gil_state = PyGILState_Ensure(); - if (option.enable_trace_python_function) { + if (options_.enable_trace_python_function) { SetProfilerInAllThreads(); } - if (option.enable_python_traceme) { + if (options_.enable_python_traceme) { EnableTraceMe(true); } + if (options_.end_to_end_mode) { + // When end to end mode is used, Stop() and Finalize() i.e. symbolization + // and data collection happens during C's atexit(), when Py_FinalizeEx() + // already called. + try { + auto atexit = py::module::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { + PythonHooks* singleton = PythonHooks::GetSingleton(); + singleton->Stop(); + singleton->CollectData(&(singleton->end_to_end_xplane_.emplace())); + })); + } catch (const py::error_already_set& e) { + LOG(ERROR) << "Can't install atexit handler for e2e mode." << e.what(); + } + } PyGILState_Release(gil_state); + active_session_ = true; } } -void PythonHooks::Stop(const PythonHooksOptions& option) { +void PythonHooks::Stop() { if (!Py_IsInitialized()) return; - if (option.enable_python_traceme || option.enable_trace_python_function) { + if (!active_session_) return; // Makes sure Stop() can be reentrant. + if (options_.enable_python_traceme || options_.enable_trace_python_function) { PyGILState_STATE gil_state = PyGILState_Ensure(); - if (option.enable_trace_python_function) { + if (options_.enable_trace_python_function) { ClearProfilerInAllThreads(); } - if (option.enable_python_traceme) { + if (options_.enable_python_traceme) { EnableTraceMe(false); } PyGILState_Release(gil_state); + active_session_ = false; } } -void PythonHooks::Finalize() { tracemes_.clear(); } +void PythonHooks::CollectData(XPlane* raw_plane) { + DCHECK(raw_plane); + XPlaneBuilder plane(raw_plane); + for (auto& it : entries_) { + uint64 thread_id = it.first; + auto& thread_events = it.second; + VLOG(1) << "Collecting " << thread_events.completed.size() << ":" + << thread_events.active.size() << " events on thread " << thread_id; + auto line = plane.GetOrCreateLine(thread_id); + line.SetTimestampNs(start_timestamp_ns_); + for (const auto& event : thread_events.completed) { + AddEventToXLine(event, &line, &plane); + } + if (options_.include_incomplete_events) { + uint64 now = EnvTime::NowNanos(); + while (!thread_events.active.empty()) { + auto& event = thread_events.active.top(); + event.end_time_ns = now; + AddEventToXLine(event, &line, &plane); + thread_events.active.pop(); + } + } + } + entries_.clear(); +} + +void PythonHooks::Finalize(XSpace* space) { + if (space) { + XPlane* plane = + FindOrAddMutablePlaneWithName(space, kPythonTracerPlaneName); + if (options_.end_to_end_mode) { + if (end_to_end_xplane_) { + end_to_end_xplane_->set_name(plane->name()); + plane->Swap(&*end_to_end_xplane_); + end_to_end_xplane_.reset(); + } + } else { + PyGILState_STATE gil_state = PyGILState_Ensure(); + CollectData(plane); + PyGILState_Release(gil_state); + } + } +} void PythonHooks::ProfileSlow(const py::object& frame, const string& event, const py::object& arg) { @@ -106,52 +225,58 @@ void PythonHooks::ProfileSlow(const py::object& frame, const string& event, } void PythonHooks::ProfileFast(PyFrameObject* frame, int what, PyObject* arg) { - const int64 thread_id = PyThread_get_thread_ident(); + const int64 thread_id = Env::Default()->GetCurrentThreadId(); + uint64 now = EnvTime::NowNanos(); + auto& thread_traces = entries_[thread_id]; - if (what == PyTrace_CALL) { - PyCodeObject* f_code = frame->f_code; - string filename(py::reinterpret_borrow(f_code->co_filename)); - int line_no = frame->f_lineno; - - string function; - if (f_code->co_name == nullptr) { - function = ""; - } else { - function = py::reinterpret_borrow(f_code->co_name); + switch (what) { + case PyTrace_CALL: { + PyCodeObject* f_code = frame->f_code; + thread_traces.active.emplace(now, 0, f_code, nullptr); + break; } - - tracemes_[thread_id].push_back( - absl::make_unique([&filename, line_no, &function] { - return absl::StrCat("$", io::Basename(filename), ":", line_no, " ", - function); - })); - } else if (what == PyTrace_C_CALL && PyCFunction_Check(arg)) { - // Python stack does not have a filename/line_no for native calls. - auto* func = reinterpret_cast(arg); - PyObject* module = func->m_module; - string filename; - bool filename_ok; -#if PY_MAJOR_VERSION < 3 - filename_ok = (module != nullptr && PyString_Check(module)); -#else - filename_ok = (module != nullptr && PyUnicode_Check(module)); -#endif - if (filename_ok) { - filename = py::reinterpret_borrow(module); - } else { - filename = ""; + case PyTrace_RETURN: + case PyTrace_EXCEPTION: { + if (!thread_traces.active.empty()) { + auto& entry = thread_traces.active.top(); + entry.end_time_ns = now; + thread_traces.completed.emplace_back(std::move(entry)); + thread_traces.active.pop(); + } else if (options_.include_incomplete_events) { + PyCodeObject* f_code = frame->f_code; + thread_traces.completed.emplace_back(start_timestamp_ns_, now, f_code, + nullptr); + } + break; } - - tracemes_[thread_id].push_back( - absl::make_unique([&filename, func] { - return absl::StrCat(filename, " ", func->m_ml->ml_name); - })); - } else if (what == PyTrace_RETURN || what == PyTrace_C_RETURN || - what == PyTrace_EXCEPTION || what == PyTrace_C_EXCEPTION) { - auto& thread_tracemes = tracemes_[thread_id]; - if (!thread_tracemes.empty()) { - thread_tracemes.pop_back(); + case PyTrace_C_CALL: { + if (PyCFunction_Check(arg)) { + // Python stack does not have a filename/line_no for native calls. + auto* func = reinterpret_cast(arg); + entries_[thread_id].active.emplace(now, 0, nullptr, func); + } + break; } + case PyTrace_C_RETURN: + case PyTrace_C_EXCEPTION: { + if (!thread_traces.active.empty()) { + auto& entry = thread_traces.active.top(); + entry.end_time_ns = now; + thread_traces.completed.emplace_back(std::move(entry)); + thread_traces.active.pop(); + } else if (options_.include_incomplete_events) { + // Only the end of the events is recorded, use profiler start as start. + if (PyCFunction_Check(arg)) { + // Python stack does not have a filename/line_no for native calls. + auto* func = reinterpret_cast(arg); + entries_[thread_id].completed.emplace_back(start_timestamp_ns_, now, + nullptr, func); + } + } + break; + } + default: + break; } } diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h index 582edf4a93b..b30fcc391f4 100644 --- a/tensorflow/python/profiler/internal/python_hooks.h +++ b/tensorflow/python/profiler/internal/python_hooks.h @@ -16,14 +16,16 @@ limitations under the License. #define TENSORFLOW_PYTHON_PROFILER_INTERNAL_PYTHON_HOOKS_H_ #include +#include #include #include "absl/container/flat_hash_map.h" #include "pybind11/cast.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { @@ -33,6 +35,52 @@ namespace py = ::pybind11; struct PythonHooksOptions { bool enable_trace_python_function = false; bool enable_python_traceme = true; + bool end_to_end_mode = false; + // Incomplete events are defined as those python calls which we only see + // either start or end, but not both. If we want to include them in the final + // result, profiler start, end time are used respectively to the absent + // timestamps. + bool include_incomplete_events = true; +}; + +struct PythonTraceEntry { + PythonTraceEntry(uint64 start, uint64 end, PyCodeObject* code, + PyCFunctionObject* func) + : start_time_ns(start), + end_time_ns(end), + code_object(code), + function_object(func) { + Py_XINCREF(code_object); + Py_XINCREF(function_object); + } + ~PythonTraceEntry() { + Py_XDECREF(code_object); + Py_XDECREF(function_object); + } + PythonTraceEntry(PythonTraceEntry&& other) { + start_time_ns = other.start_time_ns; + end_time_ns = other.end_time_ns; + code_object = other.code_object; + function_object = other.function_object; + other.code_object = nullptr; + other.function_object = nullptr; + } + + std::string Name() const; + + uint64 start_time_ns; + uint64 end_time_ns; + PyCodeObject* code_object; + PyCFunctionObject* function_object; + + PythonTraceEntry(const PythonTraceEntry& other) = delete; + void operator=(const PythonTraceEntry&) = delete; + void operator=(PythonTraceEntry&&) = delete; +}; + +struct PerThreadEvents { + std::deque completed; + std::stack active; }; // Singleton for tracing python function calls. @@ -41,19 +89,27 @@ class PythonHooks { static PythonHooks* GetSingleton(); void Start(const PythonHooksOptions& option); - void Stop(const PythonHooksOptions& option); - void Finalize(); + void Stop(); + void Finalize(XSpace* space); void ProfileSlow(const py::object& frame, const string& event, const py::object& arg); void ProfileFast(PyFrameObject* frame, int what, PyObject* arg); private: void EnableTraceMe(bool enable); + void CollectData(XPlane* raw_plane); void SetProfilerInAllThreads(); void ClearProfilerInAllThreads(); - absl::flat_hash_map>> tracemes_; + // entries_ are accessed when GIL is held, therefore no race conditions. + absl::flat_hash_map entries_; + uint64 start_timestamp_ns_; + bool active_session_ = false; + PythonHooksOptions options_; + // In end to end mode, Python get uninitialized before Stop()/Finalize(), we + // need to buffer the result. + absl::optional end_to_end_xplane_; }; } // namespace profiler diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 6060029a32a..8a843ebe07a 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -893,6 +893,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) { TF_SetStatus(status.get(), static_cast(code), message); TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get()); }); + m.def("TFE_CollectiveOpsCheckPeerHealth", + [](const py::handle& ctx, const char* task) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx), + task, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails); m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList, diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index 7e0278aa343..421be2596e5 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -90,6 +90,10 @@ def _events_from_logdir(test_case, logdir): class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): + def setUp(self): + super(TpuOutsideCompilationTest, self).setUp() + config.set_soft_device_placement(False) + def testResourceVariableAssignOnHost(self): strategy = get_tpu_strategy() with strategy.scope(): diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD index 5f3f4fd0e31..d48f066d294 100644 --- a/tensorflow/python/types/BUILD +++ b/tensorflow/python/types/BUILD @@ -35,6 +35,7 @@ py_strict_library( ":doc_typealias", "//tensorflow/python:tf_export", "//third_party/py/numpy", + "@typing_extensions_archive//:typing_extensions", ], ) diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index bec5aecaba0..b4506594a82 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -18,14 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import textwrap from typing import Union + import numpy as np from tensorflow.python.types import doc_typealias from tensorflow.python.util.tf_export import tf_export +if sys.version_info >= (3, 8): + from typing import Protocol # pylint:disable=g-import-not-at-top +else: + from typing_extensions import Protocol # pylint:disable=g-import-not-at-top + # TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced. # TODO(mdan): Add type annotations. @@ -67,9 +74,24 @@ class Value(Tensor): pass +class TensorProtocol(Protocol): + """Protocol type for objects that can be converted to Tensor.""" + + def __tf_tensor__(self, dtype=None, name=None): + """Converts this object to a Tensor. + + Args: + dtype: data type for the returned Tensor + name: a name for the operations which create the Tensor + Returns: + A Tensor. + """ + pass + + # TODO(rahulkamat): Add missing types that are convertible to Tensor. -TensorLike = Union[Tensor, int, float, bool, str, complex, tuple, list, - np.ndarray] +TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, complex, + tuple, list, np.ndarray] doc_typealias.document( obj=TensorLike, doc=textwrap.dedent("""\ diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 41b02a3dd4e..a34141716c6 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace tensorflow { namespace swig { diff --git a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc index f260cc1631f..90ea2dc5914 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc +++ b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc @@ -194,8 +194,9 @@ StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( // Address of the buffer in TPU memory that is being speculated. absl::optional cross_program_prefetch_addr; if (hlo_module_) { - for (const auto& [parameter, index] : - hlo_module_->CrossProgramPrefetches()) { + for (const auto& prefetch : hlo_module_->CrossProgramPrefetches()) { + const auto& parameter = prefetch.first; + const auto& index = prefetch.second; CHECK_LT(parameter, arguments.size()); // Ensure the cross program prefetched buffer doesn't alias with any // program outputs. If the input and output aliased, the buffer could be diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index b6b966258fb..b0456888a33 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -352,12 +352,6 @@ def tf_copts( def tf_openmp_copts(): return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"])) -def tfe_xla_copts(): - return select({ - "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"], - "//conditions:default": [], - }) - def tf_opts_nortti(): return [ "-fno-rtti", diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt index 7397719e656..7a25de0b5a6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt @@ -24,6 +24,10 @@ tf_module { name: "enable_mlir_graph_optimization" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "enable_tensor_float_32_execution" + argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_details" argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None" @@ -76,4 +80,8 @@ tf_module { name: "set_visible_devices" argspec: "args=[\'devices\', \'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "tensor_float_32_execution_enabled" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index a9f6f069560..da08722a7a3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -174,7 +174,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index 168539be647..1719c8bd9c7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -180,7 +180,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt index 2aff054a51d..d93c018b073 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt index ed49246e458..9fba915d01a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index 4368742d7bb..15c0ab5abbb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -174,7 +174,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 8e9409f27a9..729fdd660ca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -180,7 +180,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt index 7397719e656..7a25de0b5a6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt @@ -24,6 +24,10 @@ tf_module { name: "enable_mlir_graph_optimization" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "enable_tensor_float_32_execution" + argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_details" argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None" @@ -76,4 +80,8 @@ tf_module { name: "set_visible_devices" argspec: "args=[\'devices\', \'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "tensor_float_32_execution_enabled" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index a9f6f069560..da08722a7a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -174,7 +174,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index 168539be647..1719c8bd9c7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -180,7 +180,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt index 2aff054a51d..d93c018b073 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt index ed49246e458..9fba915d01a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index 4368742d7bb..15c0ab5abbb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -174,7 +174,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 8e9409f27a9..729fdd660ca 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -180,7 +180,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/ci_build/builds/docker_test.sh b/tensorflow/tools/ci_build/builds/docker_test.sh index b2d1dbae433..eee0a9103ff 100755 --- a/tensorflow/tools/ci_build/builds/docker_test.sh +++ b/tensorflow/tools/ci_build/builds/docker_test.sh @@ -122,7 +122,7 @@ ${GPU_EXTRA_PARAMS} ${ROCM_EXTRA_PARAMS} \ "${DOCKER_IMG_TAG}" \ /bin/bash -c "tensorflow/tools/ci_build/builds/run_pip_tests.sh && "\ "tensorflow/tools/ci_build/builds/test_tutorials.sh && "\ -"tensorflow/tools/ci_bukld/builds/integration_tests.sh" +"tensorflow/tools/ci_build/builds/integration_tests.sh" RESULT=$? diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index a22556a7d86..2c7a425e049 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -222,6 +222,7 @@ function install_macos_pip_deps { ${SUDO_CMD} ${PIP_CMD} install numpy==1.16.0 ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3 ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0 + ${SUDO_CMD} ${PIP_CMD} install typing_extensions ${SUDO_CMD} ${PIP_CMD} install --upgrade grpcio ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly ${PIP_CMD} install --user --upgrade flatbuffers diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat index 23dc09a8d59..267ea67e177 100644 --- a/tensorflow/tools/ci_build/release/common_win.bat +++ b/tensorflow/tools/ci_build/release/common_win.bat @@ -57,6 +57,7 @@ IF "%PYTHON_DIRECTORY%"=="Python37" ( @REM handle this case. %PIP_EXE% install gast==0.3.3 %PIP_EXE% install astunparse==1.6.3 +%PIP_EXE% install typing_extensions :: Set cuda related environment variables. If we are not using CUDA, these are not used. IF NOT DEFINED TF_CUDA_VERSION ( diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 93305a0707c..ed4dadad5e1 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -337,8 +337,8 @@ tensorflow::ProfilerSession::Status tensorflow::ProfilerSession::~ProfilerSession [profiler_server_impl] # profiler -tensorflow::ProfilerServer::StartProfilerServer -tensorflow::ProfilerServer::~ProfilerServer +tensorflow::profiler::ProfilerServer::StartProfilerServer +tensorflow::profiler::ProfilerServer::~ProfilerServer [profiler_client_impl] # profiler tensorflow::profiler::ProfileGrpc diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 74585cbb11d..71b17271226 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -218,6 +218,7 @@ filegroup( "@sobol_data//:LICENSE", "@tblib_archive//:LICENSE", "@termcolor_archive//:COPYING.txt", + "@typing_extensions_archive//:LICENSE", "@zlib//:zlib.h", "@clog//:LICENSE", "@cpuinfo//:LICENSE", diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index d2002b58598..9916b150c85 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -81,7 +81,7 @@ DEPENDENCY_DENYLIST = [ "//tensorflow/python:tf_optimizer", "//tensorflow/python:compare_test_proto_py", "//tensorflow/core:image_testdata", - "//tensorflow/core:lmdb_testdata", + "//tensorflow/core/lib/lmdb:lmdb_testdata", "//tensorflow/core/kernels/cloud:bigquery_reader_ops", "//tensorflow/python/debug:grpc_tensorflow_server.par", "//tensorflow/python/feature_column:vocabulary_testdata", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 5917b0fca7f..4d72ecaae04 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -260,6 +260,7 @@ setup( version=_VERSION.replace('-', ''), description=DOCLINES[0], long_description='\n'.join(DOCLINES[2:]), + long_description_content_type="text/markdown", url='https://www.tensorflow.org/', download_url='https://github.com/tensorflow/tensorflow/tags', author='Google Inc.', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 2bf8a08d727..e95866e9f09 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -518,14 +518,25 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "typing_extensions_archive", build_file = clean_dep("//third_party:typing_extensions.BUILD"), sha256 = "79ee589a3caca649a9bfd2a8de4709837400dfa00b6cc81962a1e6a1815969ae", - strip_prefix = "typing_extensions-3.7.4.2", - system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"), + strip_prefix = "typing_extensions-3.7.4.2/src_py3", + system_build_file = clean_dep("//third_party/systemlibs:typing_extensions.BUILD"), urls = [ "http://mirror.tensorflow.org/files.pythonhosted.org/packages/6a/28/d32852f2af6b5ead85d396249d5bdf450833f3a69896d76eb480d9c5e406/typing_extensions-3.7.4.2.tar.gz", "https://files.pythonhosted.org/packages/6a/28/d32852f2af6b5ead85d396249d5bdf450833f3a69896d76eb480d9c5e406/typing_extensions-3.7.4.2.tar.gz", ], ) + filegroup_external( + name = "typing_extensions_license", + licenses = ["notice"], # PSFL + sha256_urls = { + "ff17ce94e102024deb68773eb1cc74ca76da4e658f373531f0ac22d68a6bb1ad": [ + "http://mirror.tensorflow.org/raw.githubusercontent.com/python/typing/master/typing_extensions/LICENSE", + "https://raw.githubusercontent.com/python/typing/master/typing_extensions/LICENSE", + ], + }, + ) + tf_http_archive( name = "opt_einsum_archive", build_file = clean_dep("//third_party:opt_einsum.BUILD"), @@ -711,8 +722,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "9500a7209163d6bfce7d554cb71a7c189ca5de58" - LLVM_SHA256 = "6f78428f16e540ea1a88cc04ccdffdb33613c1672337f4cc7154e67ea5bc61dc" + LLVM_COMMIT = "90166c25631053eb4eaf5084358563ea268bb482" + LLVM_SHA256 = "1d37ee43092e2f1eea6cad9871a5ccab1ff0092a6080412a5a23e21618a533e8" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 091cd32d5fe..80f343023cd 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -176,6 +176,7 @@ def _header_paths(): "include/*-linux-gnu", "extras/CUPTI/include", "include/cuda/CUPTI", + "local/cuda/extras/CUPTI/include", ] @@ -188,6 +189,8 @@ def _library_paths(): "lib/*-linux-gnu", "lib/x64", "extras/CUPTI/*", + "local/cuda/lib64", + "local/cuda/extras/CUPTI/lib64", ] @@ -268,12 +271,14 @@ def _find_cuda_config(base_paths, required_version): nvcc_path, nvcc_version = _find_versioned_file(base_paths, [ "", "bin", + "local/cuda/bin", ], nvcc_name, cuda_version, get_nvcc_version) nvvm_path = _find_file(base_paths, [ "nvvm/libdevice", "share/cuda", "lib/nvidia-cuda-toolkit/libdevice", + "local/cuda/nvvm/libdevice", ], "libdevice*.10.bc") cupti_header_path = _find_file(base_paths, _header_paths(), "cupti.h") diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD index 92d1535b5ee..a63c38c5821 100644 --- a/third_party/llvm/llvm.autogenerated.BUILD +++ b/third_party/llvm/llvm.autogenerated.BUILD @@ -3324,6 +3324,7 @@ cc_library( ":IPO", ":InstCombine", ":Instrumentation", + ":ObjCARC", ":Scalar", ":Support", ":Target", diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 48b986bbdc1..dcbaab9edd4 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -190,6 +190,7 @@ posix_cmake_vars = { "HAVE_PTHREAD_H": 1, "HAVE_SIGNAL_H": 1, "HAVE_STDINT_H": 1, + "HAVE_SYSEXITS_H": 1, "HAVE_SYS_IOCTL_H": 1, "HAVE_SYS_MMAN_H": 1, "HAVE_SYS_PARAM_H": 1, diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 445b5474065..8c1a5eb942b 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -75,8 +75,8 @@ cc_library( "src/cpu/**/*.cpp", "src/cpu/**/*.hpp", "src/cpu/xbyak/*.h", - "src/cpu/jit_utils/jitprofiling/*.c", - "src/cpu/jit_utils/jitprofiling/*.h", + "src/cpu/x64/jit_utils/jitprofiling/*.c", + "src/cpu/x64/jit_utils/jitprofiling/*.h", ]) + [ ":dnnl_config_h", ":dnnl_version_h", diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 94129a29b84..7d1495b5cfb 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -595,6 +595,7 @@ cc_library( ":LinalgToLLVM", ":LinalgToSPIRV", ":LinalgToStandard", + ":OpenMPToLLVM", ":SCFToGPUPass", ":SCFToStandard", ":SPIRVToLLVM", @@ -1820,6 +1821,61 @@ gentbl( ], ) +cc_library( + name = "PDLInterpDialect", + srcs = glob([ + "lib/Dialect/PDLInterp/IR/*.cpp", + "lib/Dialect/PDLInterp/IR/*.h", + ]), + hdrs = glob([ + "include/mlir/Dialect/PDLInterp/IR/*.h", + ]), + includes = ["include"], + deps = [ + ":IR", + ":InferTypeOpInterface", + ":PDLDialect", + ":PDLInterpOpsIncGen", + ":SideEffects", + ":Support", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "PDLInterpOpsTdFiles", + srcs = [ + "include/mlir/Dialect/PDL/IR/PDLBase.td", + "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td", + "include/mlir/Interfaces/SideEffectInterfaces.td", + ":OpBaseTdFiles", + ], +) + +gentbl( + name = "PDLInterpOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-op-decls", + "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc", + ), + ( + "-gen-op-defs", + "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc", + ), + ( + "-gen-dialect-decls -dialect=pdl_interp", + "include/mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td", + td_srcs = [ + ":PDLInterpOpsTdFiles", + ], +) + # TODO(gcmn): Update SPIRV dependencies so that they map better to cmake files. filegroup( name = "SPIRVOpsTdFiles", @@ -2934,7 +2990,9 @@ cc_library( ":NVVMDialect", ":OpenACCDialect", ":OpenMPDialect", + ":OpenMPToLLVM", ":PDLDialect", + ":PDLInterpDialect", ":QuantOps", ":QuantPassIncGen", ":ROCDLDialect", @@ -3359,6 +3417,30 @@ cc_library( ], ) +cc_library( + name = "OpenMPToLLVM", + srcs = glob([ + "lib/Conversion/OpenMPToLLVM/*.cpp", + "lib/Conversion/OpenMPToLLVM/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/OpenMPToLLVM/*.h", + ]), + includes = ["include"], + deps = [ + ":ConversionPassIncGen", + ":IR", + ":LLVMDialect", + ":OpenMPDialect", + ":Pass", + ":StandardOps", + ":StandardToLLVM", + ":Transforms", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + ], +) + ## QuantOps dialect filegroup( name = "QuantizationOpsTdFiles", diff --git a/third_party/typing_extensions.BUILD b/third_party/typing_extensions.BUILD index efd526cd491..f3b6c26e295 100644 --- a/third_party/typing_extensions.BUILD +++ b/third_party/typing_extensions.BUILD @@ -4,11 +4,17 @@ licenses(["notice"]) # PSF -exports_files(["LICENSE"]) - py_library( name = "typing_extensions", - srcs = ["src_py3/typing_extensions.py"], + srcs = ["typing_extensions.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], ) + +genrule( + name = "license", + srcs = ["@astunparse_license"], + outs = ["LICENSE"], + cmd = "cp $< $@", + visibility = ["//visibility:public"], +)