diff --git a/README.md b/README.md index 05ddb90fabc..27032043e07 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ +[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow) +[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow) + + **`Documentation`** | ------------------- | [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 4326b723f74..d49f679083e 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -223,6 +223,7 @@ tf_cuda_cc_test( ":c_api_test_util", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -371,6 +372,22 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "custom_device_testutil", + testonly = True, + srcs = ["custom_device_testutil.cc"], + hdrs = ["custom_device_testutil.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_api", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "custom_device_test", size = "small", @@ -381,6 +398,7 @@ tf_cc_test( ":c_api", ":c_api_experimental", ":c_api_test_util", + ":custom_device_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", @@ -448,6 +466,7 @@ filegroup( ], exclude = [ "c_api_experimental.cc", + "*c_api_tfrt*", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 51ee82e55aa..b34d1026e08 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,6 +21,10 @@ limitations under the License. 
#include #include +// clang-format off +#include "tensorflow/core/platform/platform.h" +// clang-format on + #include "absl/algorithm/container.h" #include "absl/container/fixed_array.h" #include "absl/memory/memory.h" @@ -31,12 +35,14 @@ limitations under the License. #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/tf_tensor_internal.h" +#ifdef PLATFORM_GOOGLE +#include "tensorflow/c/eager/c_api_tfrt.h" +#endif #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/device_filters.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -676,6 +682,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { + if (opts->use_tfrt) { +#ifdef PLATFORM_GOOGLE + status->status = tensorflow::Status::OK(); + return new TFE_Context{new tfrt::ContextInterface()}; +#else + status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); + return nullptr; +#endif + } std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -1669,6 +1684,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { }; } // namespace +extern "C" { + void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, const char* device_name, void* device_info, TF_Status* status) { @@ -1679,3 +1696,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, status->status = 
context->RegisterCustomDevice(device_name, std::move(custom_device)); } + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 0037f2e81c8..dc1f9eaade3 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice { // This API is highly experimental, and in particular is expected to change when // it starts supporting operations with attributes and when tf.function support // is added. -void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, - const char* device_name, void* device_info, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx, + TFE_CustomDevice device, + const char* device_name, + void* device_info, + TF_Status* status); TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 55f1941ce89..e61cf7ef040 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -19,6 +19,10 @@ limitations under the License. 
#include +// clang-format off +#include "tensorflow/core/platform/platform.h" +// clang-format on + #include "absl/strings/match.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -584,9 +588,10 @@ TEST(CAPI, TensorHandleDevices) { TFE_DeleteContext(ctx); } -void ExecuteAdd(bool async, bool forward_input) { +void ExecuteAdd(bool async, bool forward_input, bool tfrt) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, tfrt); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -651,7 +656,6 @@ void ExecuteAdd(bool async, bool forward_input) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval); - TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float result[100 * 100] = {0}; @@ -661,12 +665,42 @@ void ExecuteAdd(bool async, bool forward_input) { for (int i = 0; i < 100 * 100; ++i) { EXPECT_EQ(2.0f, result[i]); } + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } -TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); } -TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); } -TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); } -TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); } +TEST(CAPI, ExecuteAdd) { + ExecuteAdd( + /*async=*/false, + /*forward_input*/ false, + /*tfrt*/ false); +} +TEST(CAPI, ExecuteAddAsync) { + ExecuteAdd( + /*async=*/true, + /*forward_input*/ false, + /*tfrt*/ false); +} +TEST(CAPI, ExecuteAddForward) { + ExecuteAdd( + /*async=*/false, + /*forward_input*/ true, + /*tfrt*/ false); +} +TEST(CAPI, ExecuteAddForwardAsync) { + ExecuteAdd( + /*async=*/true, + /*forward_input*/ true, + /*tfrt*/ false); +} +#ifdef PLATFORM_GOOGLE +// TODO(b/153349425): Add add forwarding tests for TFRT +TEST(CAPI, 
ExecuteAddTfrt) { + ExecuteAdd( + /*async=*/false, + /*forward_input*/ false, + /*tfrt*/ true); +} +#endif void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 8f13c1e5151..1ec9e9bd99a 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -20,129 +20,11 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/custom_device_testutil.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/test.h" -namespace { - -struct LoggingDevice { - tensorflow::string device_name; - tensorflow::string underlying_device; - // Set to true whenever a TensorHandle is copied onto the device - bool* arrived_flag; - // Set to true whenever an operation is executed - bool* executed_flag; -}; - -struct LoggedTensor { - TFE_TensorHandle* tensor; - LoggedTensor() = delete; - explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {} - ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); } -}; - -void LoggedTensorDeallocator(void* data, size_t len, void* arg) { - delete reinterpret_cast(data); -} - -TFE_TensorHandle* MakeLoggedTensorHandle( - TFE_Context* context, const tensorflow::string& logging_device_name, - std::unique_ptr t, TF_Status* status) { - std::vector shape(TFE_TensorHandleNumDims(t->tensor, status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - for (int i = 0; i < shape.size(); ++i) { - shape[i] = TFE_TensorHandleDim(t->tensor, i, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - auto dtype = TFE_TensorHandleDataType(t->tensor); - return TFE_NewTensorHandleFromDeviceMemory( - context, logging_device_name.c_str(), dtype, shape.data(), shape.size(), - t.release(), 1, 
&LoggedTensorDeallocator, nullptr, status); -} - -TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context, - TFE_TensorHandle* tensor, - TF_Status* status, void* device_info) { - LoggingDevice* dev = reinterpret_cast(device_info); - TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice( - tensor, context, dev->underlying_device.c_str(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - auto dst = std::make_unique(t); - *(dev->arrived_flag) = true; - return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst), - status); -} - -TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, - TFE_TensorHandle* tensor, - const char* target_device_name, - TF_Status* status, - void* device_info) { - TF_SetStatus(status, TF_INTERNAL, - "Trying to copy a tensor out of a logging device."); - return nullptr; -} - -void LoggingDeviceExecute(TFE_Context* context, int num_inputs, - TFE_TensorHandle** inputs, const char* operation_name, - const TFE_OpAttrs* attributes, int* num_outputs, - TFE_TensorHandle** outputs, TF_Status* s, - void* device_info) { - LoggingDevice* dev = reinterpret_cast(device_info); - TFE_Op* op(TFE_NewOp(context, operation_name, s)); - if (TF_GetCode(s) != TF_OK) return; - TFE_OpAddAttrs(op, attributes); - TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); - for (int j = 0; j < num_inputs; ++j) { - TFE_TensorHandle* input = inputs[j]; - const char* input_device = TFE_TensorHandleDeviceName(input, s); - if (TF_GetCode(s) != TF_OK) return; - if (dev->device_name == input_device) { - LoggedTensor* t = reinterpret_cast( - TFE_TensorHandleDevicePointer(input, s)); - if (TF_GetCode(s) != TF_OK) return; - TFE_OpAddInput(op, t->tensor, s); - } else { - TFE_OpAddInput(op, input, s); - } - if (TF_GetCode(s) != TF_OK) return; - } - std::vector op_outputs(*num_outputs); - TFE_Execute(op, op_outputs.data(), num_outputs, s); - TFE_DeleteOp(op); - if (TF_GetCode(s) != TF_OK) return; - std::vector unwrapped_outputs; - for (auto* handle : 
op_outputs) { - unwrapped_outputs.push_back(handle); - } - for (int i = 0; i < *num_outputs; ++i) { - auto logged_tensor = std::make_unique(unwrapped_outputs[i]); - outputs[i] = MakeLoggedTensorHandle(context, dev->device_name, - std::move(logged_tensor), s); - } - *(dev->executed_flag) = true; -} - -void DeleteLoggingDevice(void* device_info) { - delete reinterpret_cast(device_info); -} - -void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag, - TF_Status* status) { - TFE_CustomDevice custom_device; - custom_device.copy_tensor_to_device = &CopyToLoggingDevice; - custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; - custom_device.delete_device = &DeleteLoggingDevice; - custom_device.execute = &LoggingDeviceExecute; - LoggingDevice* device = new LoggingDevice; - device->arrived_flag = arrived_flag; - device->executed_flag = executed_flag; - device->device_name = name; - device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; - TFE_RegisterCustomDevice(context, custom_device, name, device, status); -} TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { std::unique_ptr status( @@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) { tensorflow::string( TFE_TensorHandleBackingDeviceName(var_value, status.get()))); TFE_TensorHandle* var_value_unpacked = - reinterpret_cast( - TFE_TensorHandleDevicePointer(var_value, status.get())) - ->tensor; + UnpackTensorHandle(var_value, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); std::unique_ptr resolved_value( TFE_TensorHandleResolve(var_value_unpacked, status.get()), @@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) << TF_Message(status.get()); } - -} // namespace diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc new file mode 100644 index 00000000000..28de3665653 --- 
/dev/null +++ b/tensorflow/c/eager/custom_device_testutil.cc @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A simple logging device to test custom device registration. +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +struct LoggingDevice { + tensorflow::string device_name; + tensorflow::string underlying_device; + // Set to true whenever a TensorHandle is copied onto the device + bool* arrived_flag; + // Set to true whenever an operation is executed + bool* executed_flag; +}; + +struct LoggedTensor { + TFE_TensorHandle* tensor; + LoggedTensor() = delete; + explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {} + ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); } +}; + +void LoggedTensorDeallocator(void* data, size_t len, void* arg) { + delete reinterpret_cast(data); +} + +TFE_TensorHandle* MakeLoggedTensorHandle( + TFE_Context* context, const tensorflow::string& logging_device_name, + std::unique_ptr t, TF_Status* status) { + std::vector shape(TFE_TensorHandleNumDims(t->tensor, status)); + if (TF_GetCode(status) != TF_OK) return 
nullptr; + for (int i = 0; i < shape.size(); ++i) { + shape[i] = TFE_TensorHandleDim(t->tensor, i, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + auto dtype = TFE_TensorHandleDataType(t->tensor); + return TFE_NewTensorHandleFromDeviceMemory( + context, logging_device_name.c_str(), dtype, shape.data(), shape.size(), + t.release(), 1, &LoggedTensorDeallocator, nullptr, status); +} + +TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status, void* device_info) { + LoggingDevice* dev = reinterpret_cast(device_info); + TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice( + tensor, context, dev->underlying_device.c_str(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + auto dst = std::make_unique(t); + *(dev->arrived_flag) = true; + return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst), + status); +} + +TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + const char* target_device_name, + TF_Status* status, + void* device_info) { + TF_SetStatus(status, TF_INTERNAL, + "Trying to copy a tensor out of a logging device."); + return nullptr; +} + +void LoggingDeviceExecute(TFE_Context* context, int num_inputs, + TFE_TensorHandle** inputs, const char* operation_name, + const TFE_OpAttrs* attributes, int* num_outputs, + TFE_TensorHandle** outputs, TF_Status* s, + void* device_info) { + LoggingDevice* dev = reinterpret_cast(device_info); + TFE_Op* op(TFE_NewOp(context, operation_name, s)); + if (TF_GetCode(s) != TF_OK) return; + TFE_OpAddAttrs(op, attributes); + TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); + for (int j = 0; j < num_inputs; ++j) { + TFE_TensorHandle* input = inputs[j]; + const char* input_device = TFE_TensorHandleDeviceName(input, s); + if (TF_GetCode(s) != TF_OK) return; + if (dev->device_name == input_device) { + LoggedTensor* t = reinterpret_cast( + TFE_TensorHandleDevicePointer(input, s)); + if 
(TF_GetCode(s) != TF_OK) return; + TFE_OpAddInput(op, t->tensor, s); + } else { + TFE_OpAddInput(op, input, s); + } + if (TF_GetCode(s) != TF_OK) return; + } + std::vector op_outputs(*num_outputs); + TFE_Execute(op, op_outputs.data(), num_outputs, s); + TFE_DeleteOp(op); + if (TF_GetCode(s) != TF_OK) return; + std::vector unwrapped_outputs; + for (auto* handle : op_outputs) { + unwrapped_outputs.push_back(handle); + } + for (int i = 0; i < *num_outputs; ++i) { + auto logged_tensor = std::make_unique(unwrapped_outputs[i]); + outputs[i] = MakeLoggedTensorHandle(context, dev->device_name, + std::move(logged_tensor), s); + } + *(dev->executed_flag) = true; +} + +void DeleteLoggingDevice(void* device_info) { + delete reinterpret_cast(device_info); +} + +} // namespace + +void RegisterLoggingDevice(TFE_Context* context, const char* name, + bool* arrived_flag, bool* executed_flag, + TF_Status* status) { + TFE_CustomDevice custom_device; + custom_device.copy_tensor_to_device = &CopyToLoggingDevice; + custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; + custom_device.delete_device = &DeleteLoggingDevice; + custom_device.execute = &LoggingDeviceExecute; + LoggingDevice* device = new LoggingDevice; + device->arrived_flag = arrived_flag; + device->executed_flag = executed_flag; + device->device_name = name; + device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_RegisterCustomDevice(context, custom_device, name, device, status); +} + +TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle, + TF_Status* status) { + return reinterpret_cast( + TFE_TensorHandleDevicePointer(logged_tensor_handle, status)) + ->tensor; +} + +void AllocateLoggingDevice(const char* name, bool* arrived_flag, + bool* executed_flag, TFE_CustomDevice** device, + void** device_info) { + TFE_CustomDevice* custom_device = new TFE_CustomDevice; + custom_device->copy_tensor_to_device = &CopyToLoggingDevice; + 
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice; + custom_device->delete_device = &DeleteLoggingDevice; + custom_device->execute = &LoggingDeviceExecute; + *device = custom_device; + LoggingDevice* logging_device = new LoggingDevice; + logging_device->arrived_flag = arrived_flag; + logging_device->executed_flag = executed_flag; + logging_device->device_name = name; + logging_device->underlying_device = + "/job:localhost/replica:0/task:0/device:CPU:0"; + *device_info = reinterpret_cast(logging_device); +} diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h new file mode 100644 index 00000000000..509df7d3e3e --- /dev/null +++ b/tensorflow/c/eager/custom_device_testutil.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ + +// A simple logging device to test custom device registration. 
+#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_status.h" + +void RegisterLoggingDevice(TFE_Context* context, const char* name, + bool* arrived_flag, bool* executed_flag, + TF_Status* status); +void AllocateLoggingDevice(const char* name, bool* arrived_flag, + bool* executed_flag, TFE_CustomDevice** device, + void** device_info); +TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle, + TF_Status* status); + +#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD new file mode 100644 index 00000000000..9d787d26433 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -0,0 +1,38 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "parallel_device", + srcs = ["parallel_device.cc"], + hdrs = ["parallel_device.h"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + ], +) + +tf_cc_test( + name = "parallel_device_test", + srcs = ["parallel_device_test.cc"], + deps = [ + ":parallel_device", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc new file mode 100644 index 00000000000..bd5d8e777f2 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -0,0 +1,597 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" + +namespace tensorflow { +namespace eager { +namespace { + +// Functor for making unique_ptrs slightly more ergonomic. Using +// decltype(delete_fn) in the unique_ptr's second template argument requires +// passing a function pointer to delete_fn when constructing the unique_ptr. 
+class TensorHandleDeleter { + public: + void operator()(TFE_TensorHandle* to_delete) const { + TFE_DeleteTensorHandle(to_delete); + } +}; + +using TensorHandlePtr = std::unique_ptr; + +class OpDeleter { + public: + void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); } +}; + +using OpPtr = std::unique_ptr; + +class ExecutorDeleter { + public: + void operator()(TFE_Executor* to_delete) const { + TFE_DeleteExecutor(to_delete); + } +}; + +using ExecutorPtr = std::unique_ptr; + +class ParallelTensor; + +using MaybeParallelTensorOwned = + absl::variant, TensorHandlePtr>; +using MaybeParallelTensorUnowned = + absl::variant; + +// Creates a vector of `count` new executors (threads). +std::vector MakeExecutors(size_t count) { + std::vector executors; + executors.reserve(count); + for (int i = 0; i < count; ++i) { + executors.emplace_back(TFE_NewExecutor(true /* is_async */)); + } + return executors; +} + +// A representation of the custom device passed in and out of the TFE custom +// device APIs, providing context about the parallel device to +// ParallelDeviceExecute. +class ParallelDevice { + public: + ParallelDevice(const std::string& name, + const std::vector& devices); + + // Helper to copy a tensor handle from another device once for each component + // of the ParallelDevice. + // + // Sets a bad status and returns a nullptr if `tensor` is already on the + // ParallelDevice, or if the individual copies fail. + std::unique_ptr CopyToParallelDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status) const; + + // Takes a description of a single operation being executed on the + // ParallelDevice, and in turn runs one operation per component device with + // its corresponding inputs from the input ParallelTensors (or + // implicitly-mirrored tensors on other devices). Wraps the resulting + // per-device and per-output TFE_TensorHandles into one ParallelTensor per + // output of the original operation. 
+ // + // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or + // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput + // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel + // tensor, but other operations will implicitly broadcast non-parallel input + // tensors across the ParallelDevice's component devices. + // + // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput, + // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput + // causes `Execute` to return non-parallel tensors. + // + // Attributes are forwarded to executed operations unmodified. + // + // The returned optional has a value if and only if `status` evaluates to + // TF_OK. + absl::optional> Execute( + TFE_Context* context, std::vector inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, TF_Status* status) const; + + // Implements the parallel case for `Execute`, where all of the outputs of the + // operation are ParallelTensors, and all inputs are either ParallelTensors or + // should be implicitly broadcast. This means the operation is not + // TPUReplicatedInput or TPUReplicatedOutput. + // + // The returned optional has a value if and only if `status` evaluates to + // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or + // if sanity checks on dtypes/metadata fail. + absl::optional>> + ExecuteParallelOperation(TFE_Context* context, + std::vector inputs, + const char* operation_name, + const TFE_OpAttrs* attributes, + int expected_max_outputs, TF_Status* status) const; + + const std::string& device_name() const { return device_name_; } + + private: + // The name of the parallel device + // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0") + const std::string device_name_; + // A sequence of device names, indicating which devices replicated operations + // are forwarded to. 
+ const std::vector underlying_devices_; + // A sequence of TFE_Executors, one per device, for executing operations in + // parallel. + const std::vector executors_; +}; + +// The internal representation of a TFE_TensorHandle placed on a +// ParallelDevice. Contains a tuple of tensors, one on each of the +// `underlying_devices_` of the ParallelDevice. +class ParallelTensor { + public: + // Construct a ParallelTensor from TensorHandles placed on the component + // devices of a ParallelDevice. + static std::unique_ptr FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, TF_Status* status); + + // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it. + static TensorHandlePtr AsTensorHandle(TFE_Context* context, + std::unique_ptr t, + TF_Status* status); + + size_t num_tensors() const { return tensors_.size(); } + TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); } + + private: + ParallelTensor(const ParallelDevice& device, + std::vector tensors, + std::vector shape, const TF_DataType dtype) + : device_(device), + tensors_(std::move(tensors)), + shape_(std::move(shape)), + dtype_(dtype) {} + + const ParallelDevice& device_; + const std::vector tensors_; + const std::vector shape_; + const TF_DataType dtype_; +}; + +ParallelDevice::ParallelDevice(const std::string& name, + const std::vector& devices) + : device_name_(name), + underlying_devices_(devices), + executors_(MakeExecutors(underlying_devices_.size())) {} + +std::unique_ptr ParallelDevice::CopyToParallelDevice( + TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const { + const char* current_device = TFE_TensorHandleDeviceName(tensor, status); + if (device_name_ == current_device) { + std::string message(absl::StrCat( + "Tried to copy a TensorHandle to its existing device: ", device_name_)); + TF_SetStatus(status, TF_INTERNAL, message.c_str()); + return nullptr; + } + std::vector components; + 
components.reserve(underlying_devices_.size()); + for (const std::string& underlying_device_name : underlying_devices_) { + TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice( + tensor, context, underlying_device_name.c_str(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(t); + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + +absl::optional> ParallelDevice::Execute( + TFE_Context* context, std::vector inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, TF_Status* status) const { + absl::optional> result; + // TODO(allenl): We should remove "TPU" from these op names at the very least, + // or consider other ways of packing/unpacking parallel tensors. + if (operation_name == std::string("TPUReplicatedInput")) { + // Special-cased operation for packing per-device tensors into one parallel + // tensor. + if (inputs.size() != underlying_devices_.size()) { + std::string message(absl::StrCat( + "The parallel device ", device_name_, " expected ", + underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ", + inputs.size())); + TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str()); + return result; + } + std::vector components; + components.reserve(inputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + if (absl::holds_alternative(inputs[i])) { + std::string message(absl::StrCat( + "Expected all inputs to TPUReplicatedInput to be non-parallel " + "TensorHandles. 
The input ", + i, + " was a parallel tensor (already " + "placed on the parallel device).")); + TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str()); + return result; + } + components.emplace_back(TFE_TensorHandleCopySharingTensor( + absl::get(inputs[i]), status)); + } + std::vector result_content; + result_content.reserve(1); + result_content.push_back(ParallelTensor::FromTensorHandles( + *this, std::move(components), status)); + if (TF_GetCode(status) != TF_OK) return result; + result.emplace(std::move(result_content)); + return result; + } else if (operation_name == std::string("TPUReplicatedOutput")) { + // Special-cased operation for un-packing one parallel tensor into + // per-device tensors. + OpPtr op(TFE_NewOp(context, operation_name, status)); + TFE_OpAddAttrs(op.get(), attributes); + int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status); + if (TF_GetCode(status) != TF_OK) return result; + if (expected_outputs != underlying_devices_.size()) { + std::string message(absl::StrCat( + "The parallel device ", device_name_, " expected ", + underlying_devices_.size(), + " outputs for TPUReplicatedOutput, but got ", expected_outputs)); + TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str()); + return result; + } + if (absl::holds_alternative(inputs[0])) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "Expected the input to " + "TPUReplicatedOutput to be a parallel tensor (placed on the " + "parallel device)."); + return result; + } + ParallelTensor* t = absl::get(inputs[0]); + std::vector outputs; + outputs.reserve(t->num_tensors()); + for (int i = 0; i < t->num_tensors(); ++i) { + TensorHandlePtr this_output( + TFE_TensorHandleCopySharingTensor(t->tensor(i), status)); + outputs.emplace_back(std::move(this_output)); + if (TF_GetCode(status) != TF_OK) return result; + } + result.emplace(std::move(outputs)); + return result; + } + absl::optional>> + maybe_parallel_results( + ExecuteParallelOperation(context, std::move(inputs), 
operation_name, + attributes, expected_max_outputs, status)); + if (!maybe_parallel_results.has_value()) return result; + std::vector> parallel_results( + std::move(maybe_parallel_results.value())); + std::vector result_content; + result_content.reserve(parallel_results.size()); + for (std::unique_ptr& parallel_result : parallel_results) { + result_content.push_back( + MaybeParallelTensorOwned(std::move(parallel_result))); + } + result.emplace(std::move(result_content)); + return result; +} + +absl::optional>> +ParallelDevice::ExecuteParallelOperation( + TFE_Context* context, std::vector inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, TF_Status* status) const { + absl::optional>> result; + // Compute per-device per-output tensors + std::vector> per_device_output_tensors; + per_device_output_tensors.reserve(underlying_devices_.size()); + // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep + // setting the thread-local executor like this. + TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context)); + auto reset_executor = gtl::MakeCleanup([context, previous_executor]() { + TFE_ContextSetExecutorForThread(context, previous_executor); + TFE_DeleteExecutor(previous_executor); + }); + int first_op_output_count; + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + TFE_Executor* executor = executors_[device_index].get(); + // Note that the `reset_executor` cleanup sets the thread's executor back to + // the value before this function ran. 
+ TFE_ContextSetExecutorForThread(context, executor); + OpPtr op(TFE_NewOp(context, operation_name, status)); + if (TF_GetCode(status) != TF_OK) return result; + TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(), + status); + TFE_OpAddAttrs(op.get(), attributes); + for (int input_index = 0; input_index < inputs.size(); ++input_index) { + if (absl::holds_alternative(inputs[input_index])) { + // Non-parallel tensors are implicitly broadcast, i.e. set as the input + // to each parallel operation. + // + // TODO(allenl): There may be smarter ways to do this copy in some + // cases, i.e. with a collective broadcast. We'll need to be careful + // about things that are taken as inputs on the host or on their + // existing device (for multi-device functions). + TFE_OpAddInput(op.get(), + absl::get(inputs[input_index]), + status); + if (TF_GetCode(status) != TF_OK) return result; + } else { + // Parallel tensors are divided between operations by device. + TFE_OpAddInput(op.get(), + absl::get(inputs[input_index]) + ->tensor(device_index), + status); + if (TF_GetCode(status) != TF_OK) return result; + } + } + std::vector op_outputs(expected_max_outputs); + int real_num_outputs = expected_max_outputs; + // For nested devices, the inner device sees the async executor we've + // set. Inner parallel devices will just overwrite this with their own and + // then set it back to ours before returning. This means parallel devices + // which consist of several aliased parallel devices would hypothetically + // deadlock if the outer parallel device ran one collective with a group + // size equal to the total number of aliased physical devices. Currently + // physical devices cannot participate in a single collective reduction + // multiple times, so this would fail earlier. + // + // TODO(allenl): Keep a map from outer executor to list of inner executors + // rather than a single list of executors so aliased nested parallel devices + // don't re-use an executor. 
+ TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status); + if (device_index == 0) { + first_op_output_count = real_num_outputs; + } else { + if (real_num_outputs != first_op_output_count) { + TF_SetStatus(status, TF_INTERNAL, + "Parallel ops produced different numbers of tensors."); + return result; + } + } + if (TF_GetCode(status) != TF_OK) return result; + std::vector this_outputs; + this_outputs.reserve(real_num_outputs); + for (int output_num = 0; output_num < real_num_outputs; ++output_num) { + this_outputs.emplace_back(op_outputs[output_num]); + } + per_device_output_tensors.push_back(std::move(this_outputs)); + } + // For each output of the original operation, pack the per-device + // TensorHandles we've computed into a single parallel TensorHandle. + std::vector> per_device_outputs; + per_device_outputs.reserve(first_op_output_count); + for (int i = 0; i < first_op_output_count; ++i) { + std::vector components; + components.reserve(underlying_devices_.size()); + for (int j = 0; j < underlying_devices_.size(); ++j) { + components.push_back(std::move(per_device_output_tensors[j][i])); + } + per_device_outputs.push_back(ParallelTensor::FromTensorHandles( + *this, std::move(components), status)); + if (TF_GetCode(status) != TF_OK) return result; + } + result.emplace(std::move(per_device_outputs)); + return result; +} + +std::unique_ptr ParallelTensor::FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, TF_Status* status) { + TF_DataType dtype = TFE_TensorHandleDataType(components[0].get()); + std::vector shape( + TFE_TensorHandleNumDims(components[0].get(), status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + for (int i = 0; i < shape.size(); ++i) { + shape[i] = TFE_TensorHandleDim(components[0].get(), i, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + + // Verify that the TensorHandle's shape and dtype match all of the component + // shapes and dtypes. 
+ for (TensorHandlePtr& component : components) { + for (int i = 0; i < shape.size(); ++i) { + int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + if (tensor_dim != shape[i]) { + // TODO(allenl): Allow shapes to differ. + TF_SetStatus(status, TF_UNIMPLEMENTED, + "Components of a ParallelTensor must currently all have " + "the same shape"); + return nullptr; + } + if (TFE_TensorHandleDataType(component.get()) != dtype) { + TF_SetStatus(status, TF_INTERNAL, + "Components of a ParallelTensor must all have " + "the same dtype"); + return nullptr; + } + } + } + + return std::unique_ptr(new ParallelTensor( + parallel_device, std::move(components), std::move(shape), dtype)); +} + +// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how +// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their +// reference counts drop to zero. +void ParallelTensorDeallocator(void* data, size_t len, void* arg) { + delete reinterpret_cast(data); +} + +TensorHandlePtr ParallelTensor::AsTensorHandle( + TFE_Context* context, std::unique_ptr t, + TF_Status* status) { + // The resulting TensorHandle owns an opaque pointer to "device memory", which + // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is + // deleted, it will call ParallelTensorDeallocator to free the struct. + ParallelTensor* t_released = t.release(); + return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory( + context, t_released->device_.device_name().c_str(), t_released->dtype_, + t_released->shape_.data(), t_released->shape_.size(), t_released, 1, + &ParallelTensorDeallocator, nullptr, status)); +} + +// For TFE_CustomDevice::copy_tensor_to_device in the parallel device +// registration. +// +// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing +// a ParallelTensor with one copy of `tensor` for each device in the +// ParallelDevice. 
+// +// Since this function is used to satisfy the TFE_CustomDevice C API, +// device_info is passed in using a C-style generic. It must always be a +// ParallelDevice. +TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status, void* device_info) { + ParallelDevice* dev = reinterpret_cast(device_info); + std::unique_ptr parallel_tensor( + dev->CopyToParallelDevice(context, tensor, status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor), + status) + .release(); +} + +// For TFE_CustomDevice::copy_tensor_from_device in the parallel device +// registration. +// +// Currently this is an error, and un-packing ParallelTensors must be performed +// explicitly by running a TPUReplicatedOutput operation on the parallel device. +// +// TODO(allenl): There are some use-cases that are only supported by copying to +// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either +// need to return something here or address these use-cases one by one. +TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + const char* target_device_name, + TF_Status* status, + void* device_info) { + TF_SetStatus(status, TF_INTERNAL, + "Trying to copy a tensor out of a parallel device."); + return nullptr; +} + +// For TFE_CustomDevice::execute in the parallel device registration. +// +// Since this function is used to satisfy the TFE_CustomDevice C API, +// device_info is passed in using a C-style generic. It must always be a +// ParallelDevice. 
+void ParallelDeviceExecute(TFE_Context* context, int num_inputs, + TFE_TensorHandle** inputs, + const char* operation_name, + const TFE_OpAttrs* attributes, int* num_outputs, + TFE_TensorHandle** outputs, TF_Status* status, + void* device_info) { + ParallelDevice* dev = reinterpret_cast(device_info); + std::vector typed_inputs; + typed_inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + const char* tensor_handle_device = + TFE_TensorHandleDeviceName(inputs[i], status); + if (TF_GetCode(status) != TF_OK) return; + if (dev->device_name() == tensor_handle_device) { + // We assume that any tensors already placed on this device are + // ParallelTensors. + typed_inputs.emplace_back(reinterpret_cast( + TFE_TensorHandleDevicePointer(inputs[i], status))); + if (TF_GetCode(status) != TF_OK) return; + } else { + typed_inputs.emplace_back(inputs[i]); + } + } + + absl::optional> maybe_typed_outputs( + dev->Execute(context, std::move(typed_inputs), operation_name, attributes, + *num_outputs, status)); + if (TF_GetCode(status) != TF_OK) return; + if (!maybe_typed_outputs.has_value()) { + TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned."); + return; + } + + std::vector typed_outputs( + std::move(maybe_typed_outputs.value())); + + if (typed_outputs.size() > *num_outputs) { + TF_SetStatus(status, TF_INTERNAL, + "The allocated output buffer was too small."); + return; + } + + for (int i = 0; i < typed_outputs.size(); ++i) { + MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i])); + if (absl::holds_alternative(typed_output)) { + outputs[i] = absl::get(typed_output).release(); + } else { + outputs[i] = ParallelTensor::AsTensorHandle( + context, + std::move(absl::get>( + typed_output)), + status) + .release(); + if (TF_GetCode(status) != TF_OK) return; + } + } + *num_outputs = typed_outputs.size(); +} + +// For TFE_CustomDevice::delete_device in the parallel device registration. 
+// +// Since this function is used to satisfy the TFE_CustomDevice C API, +// device_info is passed in using a C-style generic. It must always be a +// ParallelDevice. +void DeleteParallelDevice(void* device_info) { + delete reinterpret_cast(device_info); +} + +} // namespace + +void RegisterParallelDevice(TFE_Context* context, const char* device_name, + const char** underlying_devices, + int num_underlying_devices, TF_Status* status) { + TFE_CustomDevice custom_device; + custom_device.copy_tensor_to_device = &CopyToParallelDevice; + custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice; + custom_device.delete_device = &DeleteParallelDevice; + custom_device.execute = &ParallelDeviceExecute; + std::vector underlying_devices_vector; + underlying_devices_vector.reserve(num_underlying_devices); + for (int device_index = 0; device_index < num_underlying_devices; + ++device_index) { + underlying_devices_vector.push_back(underlying_devices[device_index]); + } + ParallelDevice* d = + new ParallelDevice(device_name, underlying_devices_vector); + TFE_RegisterCustomDevice(context, custom_device, device_name, d, status); +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h b/tensorflow/c/eager/parallel_device/parallel_device.h new file mode 100644 index 00000000000..b106524401f --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ +#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ + +#include "tensorflow/c/eager/c_api.h" + +namespace tensorflow { +namespace eager { + +// Register a parallel device named `device_name` which forwards operations to +// `underlying_devices`, maintaining "parallel tensors" with components placed +// on each underlying device. +// +// For example if `device_name` is +// "/job:localhost/replica:0/task:0/device:CUSTOM:0" +// and `underlying_devices` is +// {"/job:localhost/replica:0/task:0/device:GPU:0", +// "/job:localhost/replica:0/task:0/device:GPU:1"} +// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1. +// +// Implicit copies onto `device_name` are allowed, replicating the value once +// per device in `underlying_devices`. Implicit copies off of the device throw +// an error. +// +// All component tensors must have the same dtype. Currently they must also have +// the same shape, although this requirement may be relaxed in the future. +// +// `device_name` must not name an existing physical or custom device (see +// the documentation for TFE_RegisterCustomDevice for more information). +// +// Tensors may be copied on or off the device explicitly using +// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with +// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the +// parallel device creates a parallel tensor `x` with `a` on the first of +// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked = +// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor +// into its components. +// +// `context` owns the parallel device. 
`underlying_devices` must stay valid +// while the parallel device is in use. +void RegisterParallelDevice(TFE_Context* context, const char* device_name, + const char** underlying_devices, + int num_underlying_devices, TF_Status* status); + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc new file mode 100644 index 00000000000..41c7d64e231 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -0,0 +1,917 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + +// NOTE(allenl): These tests currently go through TFE_Execute and so are +// integration testing rather than purely testing the parallel device. They +// correspond fairly well to the implementation, but testing the C++ directly is +// another option. + +// Functor for making unique_ptr to TFE_TensorHandle slightly more +// ergonomic. 
Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second +// template argument requires passing a function pointer to +// TFE_DeleteTensorHandle when constructing the unique_ptr. +class TensorHandleDeleter { + public: + void operator()(TFE_TensorHandle* to_delete) { + TFE_DeleteTensorHandle(to_delete); + } +}; + +using TensorHandlePtr = std::unique_ptr; + +// A helper for performing common operations on variables. A much more +// restricted stand-in for tf.Variable in Python. +class Variable { + public: + // Construct a Variable from a resource-dtype TFE_TensorHandle and an + // indication of the dtype of the variable's value. + // + // Note that creating this resource-dtype handle can fail, so `Create` is a + // separate static method which returns a status. + Variable(TFE_TensorHandle* handle, TF_DataType type) + : handle_(handle), type_(type) {} + + // Helper for constructing a resource handle and wrapping it in a `Variable` + // object. + static Variable* Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status); + // Dereferences the backing buffer for the variable. Note that since this can + // fail (it runs operations), it must be called explicitly and the resulting + // `status` checked. + void Destroy(TFE_Context* context, TF_Status* status); + + // Reads from the variable. + TensorHandlePtr Read(TFE_Context* context, TF_Status* status); + // Assigns a new value to the variable. + void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status); + // Adds `value` to the existing value of the variable. + void AssignAdd(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status); + + private: + // Helper for running any single-argument assignment ops (Assign, AssignAdd, + // AssignSub, ...). 
+ void GeneralAssignment(const char* op_name, TFE_Context* context, + TFE_TensorHandle* value, TF_Status* status); + + // The a handle for the resource-dtype tensor pointing to the variable's + // buffer. + TFE_TensorHandle* handle_; + // The dtype of the variable's buffer (input dtype for assignments, output + // dtype of read operations). + TF_DataType type_; +}; + +Variable* Variable::Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "dtype", type); + TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status); + TFE_OpSetAttrString(op.get(), "container", "", 0); + // Use the special GUID for no buffer sharing + // + // TODO(allenl): Should we provide a better API for this? AFAIK this is the + // only reasonable way to make variables with no aliasing using the eager C + // API. + std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(), + no_sharing.length()); + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op.get(), &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return new Variable(var_handle, type); +} + +void Variable::Destroy(TFE_Context* context, TF_Status* status) { + // Free the backing buffer for the variable. 
+ std::unique_ptr op( + TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return; + int num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; + // Delete the variable handle itself. + TFE_DeleteTensorHandle(handle_); +} + +TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "dtype", type_); + int num_retvals = 1; + TFE_TensorHandle* var_value = nullptr; + TFE_Execute(op.get(), &var_value, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(var_value); +} + +void Variable::GeneralAssignment(const char* op_name, TFE_Context* context, + TFE_TensorHandle* value, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, op_name, status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetAttrType(op.get(), "dtype", type_); + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpAddInput(op.get(), value, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return; + 
TFE_OpSetDevice(op.get(), device, status); + + int num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; +} + +void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status) { + GeneralAssignment("AssignAddVariableOp", context, value, status); +} + +void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status) { + GeneralAssignment("AssignVariableOp", context, value, status); +} + +// Passed to `TF_NewTensor` to indicate how an array of floats should be +// deleted. +static void FloatDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + +// Creates a TFE_TensorHandle with value `v`. +TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) { + const int num_bytes = sizeof(float); + float* values = new float[1]; + values[0] = v; + std::unique_ptr tensor( + TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator, + nullptr), + TF_DeleteTensor); + return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); +} + +// Creates a rank-one TFE_TensorHandle with value `v`. +TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, + TF_Status* status) { + const int num_bytes = v.size() * sizeof(float); + float* values = new float[v.size()]; + memcpy(values, v.data(), num_bytes); + int64_t dims = v.size(); + std::unique_ptr tensor( + TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes, + &FloatDeallocator, nullptr), + TF_DeleteTensor); + return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); +} + +// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. 
+template +void ExtractPerDeviceValues( + TFE_Context* context, TFE_TensorHandle* input, + std::array* components, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas); + TFE_OpAddInput(op.get(), input, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(input, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return; + + TFE_TensorHandle* result_handles[num_replicas]; + int num_retvals = num_replicas; + TFE_Execute(op.get(), result_handles, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; + for (int i = 0; i < num_replicas; ++i) { + (*components)[i].reset(result_handles[i]); + } +} + +// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle. +template +TensorHandlePtr CreatePerDeviceValues( + TFE_Context* context, + const std::array& components, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrInt(op.get(), "N", num_replicas); + for (int i = 0; i < num_replicas; ++i) { + TFE_OpAddInput(op.get(), components[i], status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, + TFE_TensorHandle* second, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "Mul", status), TFE_DeleteOp); + if 
(TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), first, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), second, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + const char* first_device = TFE_TensorHandleDeviceName(first, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(op.get(), first_device, status); + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +// Assert that `handle` is equal to `expected_value`. +void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr value_zero( + TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(expected_value, + *static_cast(TF_TensorData(value_zero.get()))); +} + +// Create and modify a variable placed on a parallel device which composes +// `first_device` and `second_device`. +void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, + const char* second_device) { + // Register the custom device + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::array underlying_devices{first_device, second_device}; + tensorflow::eager::RegisterParallelDevice( + context, device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a variable handle (uninitialized to start) placed on the parallel + // device. 
+ std::function variable_deleter = [&](Variable* to_delete) { + to_delete->Destroy(context, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + delete to_delete; + }; + std::unique_ptr variable( + Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name, + status.get()), + variable_deleter); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Assign an initial value to the variable, implicitly mirroring it to each + // component device. + { + TensorHandlePtr initial_value = FloatTensorHandle(20., status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + variable->Assign(context, initial_value.get(), status.get()); + } + + // Read from the variable and verify that we have a parallel tensor. + { + TensorHandlePtr read = variable->Read(context, status.get()); + std::array components; + ExtractPerDeviceValues(context, read.get(), &components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + AssertScalarFloatEq(components[0].get(), 20.); + AssertScalarFloatEq(components[1].get(), 20.); + + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } + + // Add a parallel tensor with different values on each device to the variable. 
+ { + TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); + std::array components{value_one.get(), + value_two.get()}; + TensorHandlePtr combined_value = + CreatePerDeviceValues(context, components, device_name, status.get()); + variable->AssignAdd(context, combined_value.get(), status.get()); + } + + // Read the variable and verify that each component has the right modified + // value. + { + TensorHandlePtr read = variable->Read(context, status.get()); + std::array components; + ExtractPerDeviceValues(context, read.get(), &components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + AssertScalarFloatEq(components[0].get(), 23.); + AssertScalarFloatEq(components[1].get(), 18.); + + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } +} + +TEST(PARALLEL_DEVICE, TestBasicCPU) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + BasicTestsForTwoDevices(context.get(), + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"); +} + +TEST(PARALLEL_DEVICE, TestBasicCPUAliased) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + 
TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + BasicTestsForTwoDevices(context.get(), + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:0"); +} + +TEST(PARALLEL_DEVICE, TestBasicTPUAliased) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Skip the test if no TPU is available. + std::unique_ptr devices( + TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + bool has_tpu = false; + for (int device_index = 0; device_index < TF_DeviceListCount(devices.get()); + ++device_index) { + std::string device_type = + TF_DeviceListType(devices.get(), device_index, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + if (device_type == "TPU") { + has_tpu = true; + break; + } + } + if (has_tpu) { + BasicTestsForTwoDevices(context.get(), + "/job:localhost/replica:0/task:0/device:TPU:0", + "/job:localhost/replica:0/task:0/device:TPU:0"); + } +} + +TEST(PARALLEL_DEVICE, TestExplicitCopies) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + 
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector underlying_devices; + const char* first_device_name = + "/job:localhost/replica:0/task:0/device:CPU:0"; + underlying_devices.push_back(first_device_name); + const char* second_device_name = + "/job:localhost/replica:0/task:0/device:CPU:1"; + underlying_devices.push_back(second_device_name); + tensorflow::eager::RegisterParallelDevice( + context.get(), device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Copying on to a parallel device is OK. + TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice( + cpu_value.get(), context.get(), device_name, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + const char* backing_device = + TFE_TensorHandleBackingDeviceName(device_value.get(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(std::string(device_name), backing_device); + + // Un-pack the parallel tensor to verify that the copy was successful. + { + std::array components; + ExtractPerDeviceValues(context.get(), device_value.get(), &components, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // The value of the original tensor is replicated on each device. + AssertScalarFloatEq(components[0].get(), 3.); + AssertScalarFloatEq(components[1].get(), 3.); + + // Verify that the mirrors are placed on the component devices. 
+ std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } + + // Copies off of parallel devices must be explicit. + TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice( + device_value.get(), context.get(), first_device_name, status.get())); + ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL); +} + +TEST(PARALLEL_DEVICE, TestDifferentShapes) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector underlying_devices; + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); + tensorflow::eager::RegisterParallelDevice( + context.get(), device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create two vectors with different lengths + std::vector size_two_value{1., 2.}; + std::vector size_three_value{1., 2., 3.}; + TensorHandlePtr size_two( + VectorFloatTensorHandle(size_two_value, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr size_three( + 
VectorFloatTensorHandle(size_three_value, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Try to combine these values into a single parallel tensor. + std::array components{size_two.get(), size_three.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED) + << TF_Message(status.get()); +} + +TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 3), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a parallel device with two CPUs + const char* first_device_name = + "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector first_underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + tensorflow::eager::RegisterParallelDevice( + context.get(), first_device_name, first_underlying_devices.data(), + first_underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a second parallel device with the first parallel device and one + // additional CPU. 
+ const char* second_device_name = + "/job:localhost/replica:0/task:0/device:CUSTOM:1"; + std::vector second_underlying_devices{ + "/job:localhost/replica:0/task:0/device:CUSTOM:0", + "/job:localhost/replica:0/task:0/device:CPU:2"}; + tensorflow::eager::RegisterParallelDevice( + context.get(), second_device_name, second_underlying_devices.data(), + second_underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a tensor on the first parallel device + TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components{value_one.get(), value_two.get()}; + TensorHandlePtr first_combined_value = CreatePerDeviceValues( + context.get(), components, first_device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Nest the first parallel tensor into a second + TensorHandlePtr value_three(FloatTensorHandle(3., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + components[0] = first_combined_value.get(); + components[1] = value_three.get(); + TensorHandlePtr second_combined_value = CreatePerDeviceValues( + context.get(), components, second_device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr negative_one(FloatTensorHandle(3., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr multiply_result(Multiply(context.get(), + second_combined_value.get(), + negative_one.get(), status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Un-pack the parallel tensor to verify that the operation was + // successful. 
The resulting structure should be: + // second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}. + std::array second_components; + ExtractPerDeviceValues(context.get(), multiply_result.get(), + &second_components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + AssertScalarFloatEq(second_components[1].get(), 9.); + + // Verify that the mirrors are placed on the component devices. + std::string first_device = TFE_TensorHandleBackingDeviceName( + second_components[0].get(), status.get()); + ASSERT_EQ(second_underlying_devices[0], first_device); + std::string second_device = TFE_TensorHandleBackingDeviceName( + second_components[1].get(), status.get()); + ASSERT_EQ(second_underlying_devices[1], second_device); + + // Un-pack the first parallel device's tensor too + std::array first_components; + ExtractPerDeviceValues(context.get(), second_components[0].get(), + &first_components, status.get()); + AssertScalarFloatEq(first_components[0].get(), 3.); + AssertScalarFloatEq(first_components[1].get(), 6.); + + first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(), + status.get()); + ASSERT_EQ(first_underlying_devices[0], first_device); + second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(), + status.get()); + ASSERT_EQ(first_underlying_devices[1], second_device); +} + +TEST(PARALLEL_DEVICE, TestInvalidPacking) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector underlying_devices; + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); + tensorflow::eager::RegisterParallelDevice( + context.get(), device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + 
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); + { + // Try to pack two TensorHandles onto a parallel device with a single + // component. + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components{value_one.get(), + value_two.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) + << TF_Message(status.get()); + } + + { + // Try to extract the wrong number of components from a parallel tensor + std::array correct_components{value_one.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), correct_components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::array incorrect_components; + ExtractPerDeviceValues(context.get(), combined_value.get(), + &incorrect_components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) + << TF_Message(status.get()); + } + + { + // Try to pass a ParallelTensor to TPUReplicatedInput + std::array correct_components{value_one.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), correct_components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::array incorrect_components{combined_value.get()}; + TensorHandlePtr recombined_value = CreatePerDeviceValues( + context.get(), incorrect_components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) + << TF_Message(status.get()); + } + + { + // Try to pass a non-parallel tensor to TPUReplicatedOutput + std::unique_ptr op( + TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()), + TFE_DeleteOp); + if 
(TF_GetCode(status.get()) != TF_OK) return; + TFE_OpSetAttrInt(op.get(), "num_replicas", 1); + TFE_OpAddInput(op.get(), value_one.get(), status.get()); + if (TF_GetCode(status.get()) != TF_OK) return; + TFE_OpSetDevice(op.get(), device_name, status.get()); + if (TF_GetCode(status.get()) != TF_OK) return; + + TFE_TensorHandle* result_handles; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handles, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) + << TF_Message(status.get()); + } +} + +TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input, + int group_size, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + + const char* device = TFE_TensorHandleDeviceName(input, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input)); + TFE_OpSetAttrInt(op.get(), "group_size", group_size); + TFE_OpSetAttrInt(op.get(), "group_key", 0); + TFE_OpSetAttrInt(op.get(), "instance_key", 0); + const std::string merge_op("Add"); + TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(), + merge_op.length()); + const std::string final_op("Id"); + TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(), + final_op.length()); + TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0); + + TFE_OpAddInput(op.get(), input, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +TEST(PARALLEL_DEVICE, TestCollective) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + 
TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector underlying_devices; + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); + tensorflow::eager::RegisterParallelDevice( + context.get(), device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a tensor on the parallel device + TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components{value_one.get(), value_two.get()}; + TensorHandlePtr parallel_value = CreatePerDeviceValues( + context.get(), components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Run a collective sum, so each component should now be the same. 
+ TensorHandlePtr reduced( + CollectiveSum(context.get(), parallel_value.get(), 2, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::array result_components; + ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + AssertScalarFloatEq(result_components[0].get(), 3.); + AssertScalarFloatEq(result_components[1].get(), 3.); +} + +void RegisterCollectiveMulFunction(TFE_Context* context, + const char* function_name, int group_size, + TF_Status* status) { + std::unique_ptr body(TF_NewGraph(), + TF_DeleteGraph); + TF_OperationDescription* placeholder_desc = + TF_NewOperation(body.get(), "Placeholder", "Placeholder"); + TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT); + TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status); + if (TF_GetCode(status) != TF_OK) return; + TF_Output x{placeholder_op, 0}; + + TF_OperationDescription* reduce_desc = + TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce"); + TF_SetAttrType(reduce_desc, "T", TF_FLOAT); + TF_SetAttrInt(reduce_desc, "group_size", group_size); + TF_SetAttrInt(reduce_desc, "group_key", 0); + TF_SetAttrInt(reduce_desc, "instance_key", 0); + + const std::string merge_op("Mul"); + TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(), + merge_op.length()); + const std::string final_op("Id"); + TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(), + final_op.length()); + TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0); + TF_AddInput(reduce_desc, x); + TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status); + if (TF_GetCode(status) != TF_OK) return; + TF_Operation* operations[]{placeholder_op, reduce_op}; + TF_Output y{reduce_op, 0}; + const char* output_name = "y"; + std::unique_ptr function( + TF_GraphToFunction( + /* fn_body */ body.get(), /* fn_name */ function_name, + /* 
append_hash_to_fn_name */ 0, /* num_opers */ 2, + /* opers */ operations, /* ninputs */ 1, /* inputs */ &x, + /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name, + /* opts */ nullptr, /* description */ "", /* status */ status), + TF_DeleteFunction); + if (TF_GetCode(status) != TF_OK) return; + TFE_ContextAddFunction(context, function.get(), status); +} + +TEST(PARALLEL_DEVICE, TestFunction) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::vector underlying_devices; + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); + underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); + tensorflow::eager::RegisterParallelDevice( + context.get(), device_name, underlying_devices.data(), + underlying_devices.size(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* function_name = "test_reduce_mul"; + RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr value_one(FloatTensorHandle(7., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(9., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components{value_one.get(), value_two.get()}; + TensorHandlePtr parallel_value = CreatePerDeviceValues( + 
context.get(), components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::unique_ptr op( + TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetDevice(op.get(), device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpAddInput(op.get(), parallel_value.get(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* raw_result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr reduced(raw_result_handle); + + std::array result_components; + ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + AssertScalarFloatEq(result_components[0].get(), 7. * 9.); + AssertScalarFloatEq(result_components[1].get(), 7. 
* 9.); + + std::string first_device = TFE_TensorHandleBackingDeviceName( + result_components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = TFE_TensorHandleBackingDeviceName( + result_components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); +} diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 7f1590ff75d..fd4ae10595b 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -16,6 +16,12 @@ cc_library( deps = ["//tensorflow/core:test_main"], ) +filegroup( + name = "quantize_header", + srcs = ["quantize.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tfcompile_lib", srcs = [ @@ -27,6 +33,7 @@ cc_library( "codegen.h", "compile.h", "flags.h", + "quantize.h", ], defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]), visibility = ["//tensorflow/python:__pkg__"], @@ -37,7 +44,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/mlir/lite/quantization/xla:quantize", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index f83cd45f9f3..a2cba5cdf9e 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm-c/Target.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/flags.h" -#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" +#include "tensorflow/compiler/aot/quantize.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -46,6 +46,14 @@ limitations under the License. 
namespace tensorflow { namespace tfcompile { +static llvm::ManagedStatic quantize_xla; + +bool RegisterQuantizeFn(const QuantizeXlaFn& fn) { + if (*quantize_xla) return false; + *quantize_xla = fn; + return true; +} + namespace { // Compiles the XLA computation into executable code. @@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, } else { return errors::Unknown("Unknown mlir_components ", flags.mlir_components); } - if (flags.experimental_quantize) { - TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation)); + + if (flags.experimental_quantize && *quantize_xla) { + TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation)); } + if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h b/tensorflow/compiler/aot/quantize.h similarity index 55% rename from tensorflow/compiler/mlir/lite/quantization/xla/quantize.h rename to tensorflow/compiler/aot/quantize.h index 2ec5dbb02ce..add05bd0422 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h +++ b/tensorflow/compiler/aot/quantize.h @@ -13,21 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ +#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ + +#include +#include +#include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -namespace mlir { -namespace xla_hlo { +namespace tensorflow { +namespace tfcompile { -// Quantizes the model in the computation. 
-tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, - xla::XlaComputation* computation); +using QuantizeXlaFn = std::function; -} // namespace xla_hlo -} // namespace mlir +// Set the static quantization function to the `fn` if it hasn't been set. +// Return false if the static function has been set. +bool RegisterQuantizeFn(const QuantizeXlaFn& fn); -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5081df28a08..b51749bc332 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp( arg_shapes.push_back(absl::get(arg.shape)); } GraphDebugInfo debug_info; - return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, - compile_options.use_tuple_arg, - *options.flib_def, debug_info, - options.shape_representation_fn, result); + return CompileGraphToXlaHlo( + *graph, {arg_shapes.data(), arg_shapes.size()}, + options.device_type.type_string(), compile_options.use_tuple_arg, + *options.flib_def, debug_info, options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 63546db1eb0..046ae607438 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -58,6 +58,11 @@ cc_library( "//tensorflow/python:__subpackages__", ], deps = [ + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:QuantOps", + # Link jit lib to link JIT devices required to run + # xla-legalize-tf-with-tf2xla pass. 
+ "//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", @@ -65,7 +70,6 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant", - "//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", @@ -90,8 +94,6 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_lower", "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", "//tensorflow/compiler/mlir/xla:xla_test_passes", - "@llvm-project//mlir:Affine", - "@llvm-project//mlir:QuantOps", ], ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 789d06b8ac9..c4042aad12e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -307,7 +307,7 @@ cc_library( "transforms/optimize_functional_ops.cc", "transforms/prepare_composite_functions_tf.cc", "transforms/prepare_tf.cc", - "transforms/runtime_type_verify.cc", + "transforms/runtime_verify.cc", "transforms/split_merged_operands.cc", "transforms/trim_functions_tf.cc", "transforms/while_loop_outline.cc", @@ -537,6 +537,15 @@ tf_native_cc_binary( ], ) +tf_native_cc_binary( + name = "json_to_flatbuffer", + srcs = ["json_to_flatbuffer.cc"], + deps = [ + "//tensorflow/lite/schema:schema_fbs", + "@flatbuffers", + ], +) + cc_library( name = "emit_error_reporter", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 82d058964cb..2ed63fcc794 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ 
b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -36,7 +36,8 @@ struct PassConfig { form_clusters(false), unfold_batch_matmul(true), legalize_tf_while(true), - shape_inference(true) {} + shape_inference(true), + runtime_verification(true) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // added, which produces TF Lite ops. @@ -65,6 +66,8 @@ struct PassConfig { bool legalize_tf_while; // Whether to do shape inference. bool shape_inference; + // Whether to do TFLite runtime verification. + bool runtime_verification; }; } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index bc894d36e75..f1ed97cb8e7 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { mlir::tblgen::FmtContext verify_ctx; os << "::mlir::LogicalResult " << op.getCppClassName() - << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool " + << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool " "failure_on_operand_type_mismatch) {\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; verify_ctx.withOp("top"); @@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { "operand"); GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(), "result"); + + for (auto &trait : op.getTraits()) { + if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) { + continue; + } + if (trait.getDef().getValueAsString("trait") != + "OpTrait::TFLRuntimeOpTrait") { + continue; + } + + auto *val = trait.getDef().getValue("tflRuntimePredicate"); + if (!val) continue; + + mlir::tblgen::Pred pred(dyn_cast(val->getValue())); + os << tgfmt( + " if (!($0)) {\n " + " return ::mlir::LogicalResult::Failure;\n }\n", + &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx)); + 
} os << " return top.verify();\n}\n"; } diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h index 4c9a51c0351..ecc90106ced 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h @@ -16,6 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ +// tfl.abs +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.add template <> class TFLiteCostEstimator { @@ -149,6 +162,19 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.log +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.logistic template <> class TFLiteCostEstimator { @@ -240,6 +266,32 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.pow +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.prelu +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << 
op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.relu template <> class TFLiteCostEstimator { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index b20e81aefa9..ccad3cbb79e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> { let methods = [ StaticInterfaceMethod< [{Returns whether the op's operands/results are supported by runtime.}], - "LogicalResult", "VerifyTflRuntimeTypes", + "LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch) >, ]; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 45efe8f72f7..47a7b32d7e3 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -46,6 +46,30 @@ namespace mlir { #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" namespace TFL { +// Returns true when the given two types have the same shape or broadcastable +// shape within the given rank. If any given shapes are non-static, this method +// returns true. +bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs, + int max_bcast_rank) { + // Ignore shape checking on the non-static shapes for model compatibility. 
+ auto lhs_shaped_type = lhs.dyn_cast(); + if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true; + auto rhs_shaped_type = rhs.dyn_cast(); + if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true; + + if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape())) + return true; + + SmallVector result_shape; + if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(), + rhs_shaped_type.getShape(), + result_shape)) { + return false; + } + return lhs_shaped_type.getRank() <= max_bcast_rank && + rhs_shaped_type.getRank() <= max_bcast_rank; +} + //===----------------------------------------------------------------------===// // TensorFlowLiteDialect //===----------------------------------------------------------------------===// @@ -316,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, const int num_elements = result_shape_type.getNumElements(); new_values.reserve(num_elements); - for (APFloat old_value : dense_elements.getValues()) { + for (const APFloat &old_value : dense_elements.getValues()) { new_values.push_back(calculate(old_value)); } @@ -844,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { if (!shape_elements) return nullptr; SmallVector shape_data; - for (auto it : shape_elements.getValues()) { + for (const auto &it : shape_elements.getValues()) { shape_data.push_back(it.getSExtValue()); } result_type = @@ -1798,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) { int index = 0; llvm::SmallVector axes; - for (auto axis_int : perm.getValues()) { + for (const auto &axis_int : perm.getValues()) { const int64_t axis = axis_int.getSExtValue(); if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) { return op.emitOpError( diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 2bf6ca2ab89..519bd9dbfc0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td 
@@ -106,6 +106,22 @@ class DerivedShapeAttr : DerivedAttr<"ArrayRef", body>; class DerivedTFLiteTypeAttr : DerivedAttr<"tflite::TensorType", body>; +// TFL Runtime op trait predicate. +class TFL_RuntimePredOpTrait : + GenInternalOpTrait<"TFLRuntimeOpTrait"> { + Pred tflRuntimePredicate = pred; + string tflRuntimeDescription = desc; +} + +class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape< + int i, int j, int max_bcast_rank> : + TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j # + " have the same shape or broadcastable shapes within the rank " # + max_bcast_rank, + CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(" + "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j # + ").getType(), " # max_bcast_rank # ")">>; + // These additional types/type constraints here are used to decouple the ops // from runtime support for the ops. Prefer to use these types when defining // new TF_Ops for uniformity. @@ -344,7 +360,10 @@ class TFL_ConvOp : // TFL op definitions. //===----------------------------------------------------------------------===// def TFL_AbsOp : TFL_Op<"abs", [ - NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { + NoSideEffect, + SameOperandsAndResultType, + NoQuantizableResult, + TFL_GpuTargetOp]> { let summary = "Absolute value operator"; let description = [{ @@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\). 
let hasFolder = 1; } -def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, - NoSideEffect, - Commutative, - TFL_GpuTargetOp]> { +def TFL_AddOp : TFL_Op<"add", [ + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, + ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> { let summary = "Addition operator"; let description = [{ @@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs, + ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs, + TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output); let hasFolder = 1; @@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ } def TFL_LogOp: TFL_Op<"log", [ - NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { + NoSideEffect, + SameOperandsAndResultType, + NoQuantizableResult, + TFL_GpuTargetOp]> { let summary = "Natural logarithm operator"; let description = [{ @@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ let hasOptions = 1; } -def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { +def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, + NoSideEffect, + NoQuantizableResult, + TFL_GpuTargetOp]> { let summary = "Power operator"; let description = [{ @@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> { +def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> { let summary = "Parameterized Relu operator"; let description = [{ diff --git a/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc new file mode 100644 index 00000000000..4a4e7a65cd6 --- /dev/null +++ 
b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include +#include +#include + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/idl.h" // from @flatbuffers +#include "flatbuffers/util.h" // from @flatbuffers +#include "tensorflow/lite/schema/schema_generated.h" + +int main(int argc, char** argv) { + // load FlatBuffer schema (.fbs) and JSON from disk + if (argc < 3) { + std::cerr << "Missing input argument. 
Usage:\n" + << argv[0] << " \n\n"; + return 1; + } + const char* schema_path = argv[1]; + const char* json_path = argv[2]; + std::string schema; + std::string json; + + const bool status = + flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) && + flatbuffers::LoadFile(json_path, /*binary=*/false, &json); + if (!status) { + std::cerr << "couldn't load files!\n"; + return 1; + } + + // parse schema first, so we can use it to parse the data after + flatbuffers::Parser parser; + const bool schema_parse_result = + parser.Parse(schema.c_str()) && parser.Parse(json.c_str()); + if (!schema_parse_result) { + std::cerr << "Parse error.\n"; + return 1; + } + const size_t length = parser.builder_.GetSize(); + const size_t n = + std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout); + if (n != length) { + std::cerr << "print to stdout failed.\n"; + return 1; + } + return 0; +} diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 1165561cb71..a07b7b8dd1d 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; - pass_config.shape_inference = false; return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), pass_config, result); diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 681773a7e6b..c338b723a4a 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -16,9 +16,12 
@@ limitations under the License. #include +#include "llvm/ADT/StringSet.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project @@ -41,6 +44,77 @@ limitations under the License. namespace tensorflow { +Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags, + mlir::OwningModuleRef* module) { + mlir::FuncOp entry_function = nullptr; + for (auto func : module->get().getOps()) { + if (auto tf_attrs = + func.getAttrOfType("tf.entry_function")) { + // TODO(jaesung): There could be multiple entry functions. Let's handle + // such cases if there are any needs for that. + if (entry_function != nullptr) { + return errors::InvalidArgument( + "There should be only one tf.entry_function"); + } + entry_function = func; + } + } + if (entry_function == nullptr) { + return errors::InvalidArgument("no tf.entry_function found"); + } + + // Get the list of input Op names from the function attribute. 
+ mlir::DictionaryAttr tf_attrs = + entry_function.getAttrOfType("tf.entry_function"); + llvm::SmallVector function_input_names; + function_input_names.reserve(model_flags.input_arrays().size()); + auto input_attr = tf_attrs.get("inputs"); + if (!input_attr) { + return errors::InvalidArgument("no inputs attribute found"); + } + auto input_names = input_attr.cast().getValue(); + input_names.split(function_input_names, ","); + if (function_input_names.size() != model_flags.input_arrays().size()) { + return errors::InvalidArgument( + "input array size mismatch: got ", function_input_names.size(), + ", expected: ", model_flags.input_arrays().size()); + } + llvm::StringSet<> function_input_names_set; + function_input_names_set.insert(function_input_names.begin(), + function_input_names.end()); + for (const auto& input_array : model_flags.input_arrays()) { + if (function_input_names_set.count(input_array.name()) == 0) { + return errors::InvalidArgument("input array name (", input_array.name(), + ") does not exist in the given graph"); + } + } + + // Get the list of output Op names from the function attribute. 
+ llvm::SmallVector function_output_names; + function_output_names.reserve(model_flags.output_arrays().size()); + auto output_attr = tf_attrs.get("outputs"); + if (!output_attr) { + return errors::InvalidArgument("no outputs attribute found"); + } + auto output_names = output_attr.cast().getValue(); + output_names.split(function_output_names, ","); + if (function_output_names.size() != model_flags.output_arrays().size()) { + return errors::InvalidArgument( + "output array size mismatch: got ", function_output_names.size(), + ", expected: ", model_flags.output_arrays().size()); + } + llvm::StringSet<> function_output_names_set; + function_output_names_set.insert(function_output_names.begin(), + function_output_names.end()); + for (const auto& output_array : model_flags.output_arrays()) { + if (function_output_names_set.count(output_array) == 0) { + return errors::InvalidArgument("output array name (", output_array, + ") does not exist in the given graph"); + } + } + return Status::OK(); +} + Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, string* result) { @@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer( model_flags.saved_model_version(), tags, exported_names, &context)); + if (!model_flags.input_arrays().empty() || + !model_flags.output_arrays().empty()) { + TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module)); + } + mlir::TFL::PassConfig pass_config(quant_specs); bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; - pass_config.shape_inference = true; auto status = internal::ConvertMLIRToTFLiteFlatBuffer( toco_flags, std::move(module), pass_config, result); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index a17cdda2a39..6dd44e666fb 100644 --- 
a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, if (pass_config.legalize_tf_while) { pm.addPass(mlir::TFL::CreateWhileOutlinePass()); } - pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); + pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index a75135cf3b5..a63a1e4b1e5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -14,7 +14,10 @@ package( package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], - packages = ["//tensorflow/compiler/mlir/..."], + packages = [ + "//learning/brain/experimental/mlir/quantization/...", + "//tensorflow/compiler/mlir/...", + ], ) exports_files([ diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 5a5012173e2..9d5aa167ff4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -55,7 +55,8 @@ namespace quant { using QuantParamsEntry = QuantizationInfo::QuantParams; namespace { -class ImportQuantStatsPass : public FunctionPass { +class ImportQuantStatsPass + : public PassWrapper { public: explicit ImportQuantStatsPass(OperationToName op_to_name) : op_to_name_(op_to_name) {} @@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() { } // Creates an instance of the default quant parameters pass. 
-std::unique_ptr> CreateImportQuantStatsPass( +std::unique_ptr> CreateImportQuantStatsPass( OperationToName op_to_name, const std::string &stats_str) { auto pass = absl::make_unique(op_to_name); if (pass->ParseQuantStats(stats_str)) return nullptr; @@ -203,7 +204,7 @@ std::unique_ptr> CreateImportQuantStatsPass( // Creates an instance pass to import quantization stats to the operations in // the function. A custom method to get the name from the op is used because // different dialect ops might have different ways to assign the name. -std::unique_ptr> +std::unique_ptr> CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) { auto get_name_func = [](Operation *op) { Location loc = op->getLoc(); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h index 2aa5f8e2d0d..bf034d49d4a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h @@ -27,13 +27,13 @@ using OperationToName = std::function; // Creates an instance pass to import quantization stats to the operations in // the function. A custom method to get the name from the op is used because // different dialect ops might have different ways to assign the name. -std::unique_ptr> CreateImportQuantStatsPass( +std::unique_ptr> CreateImportQuantStatsPass( OperationToName op_to_name, const std::string& stats_str); // Creates an instance pass to import quantization stats to the operations in // the function. A custom method to get the name from the op is used because // different dialect ops might have different ways to assign the name. 
-std::unique_ptr> +std::unique_ptr> CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str); } // namespace quant diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 0bd914aa2e7..3d50f280d0f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) { SmallVector new_scales; new_scales.reserve(scales.size()); auto scales_iter = scales.begin(); - for (auto f : factor_values) { + for (const auto& f : factor_values) { new_scales.push_back(*(scales_iter++) * std::fabs(FloatAttr::getValueAsDouble(f))); } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h index 178daf1b1e0..1ee85a3f4eb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h @@ -25,7 +25,7 @@ namespace mlir { namespace TF { // Legalize the tf ops to the quant ops, so the quantization passes can work. -std::unique_ptr> CreateLegalizeTFToQuantPass(); +std::unique_ptr> CreateLegalizeTFToQuantPass(); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 2c0b435cc04..e216cbb9306 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -27,7 +27,7 @@ namespace TF { namespace { // Legalize TF quantization emulation ops to that in Quant ops dialect. 
-struct LegalizeTFToQuant : public FunctionPass { +struct LegalizeTFToQuant : public PassWrapper { explicit LegalizeTFToQuant() = default; LegalizeTFToQuant(const LegalizeTFToQuant &) {} @@ -151,7 +151,7 @@ void LegalizeTFToQuant::runOnFunction() { } // namespace // Creates an instance of the TensorFlow dialect to QuantOps dialect pass. -std::unique_ptr> CreateLegalizeTFToQuantPass() { +std::unique_ptr> CreateLegalizeTFToQuantPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD deleted file mode 100644 index 2bc1568eb17..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ /dev/null @@ -1,112 +0,0 @@ -load( - "//third_party/mlir:tblgen.bzl", - "gentbl", -) - -package( - default_visibility = [ - ":friends", - ], - licenses = ["notice"], # Apache 2.0 -) - -package_group( - name = "friends", - includes = ["//third_party/mlir:subpackages"], - packages = [ - "//tensorflow/compiler/aot/...", - "//tensorflow/compiler/mlir/...", - "//tensorflow/compiler/mlir/lite/...", - ], -) - -cc_library( - name = "hlo_xla_quantization_passes", - srcs = [ - "cpu_kernel_fusion.cc", - "generated_cpu_kernel_fusion.inc", - "materialize.cc", - "op_quant_spec.inc", - "propagate.cc", - ], - hdrs = [ - "passes.h", - ], - deps = [ - ":cpu_device_target", - "//tensorflow/compiler/mlir/lite/quantization:quantization_config", - "//tensorflow/compiler/mlir/lite/quantization:quantization_context", - "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/xla/client/lib:quantize", - "@com_google_absl//absl/memory", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], - alwayslink = 1, -) - -cc_library( - name = 
"cpu_device_target", - srcs = [ - "cpu_device_target.cc", - ], - hdrs = [ - "cpu_device_target.h", - ], - deps = [ - "//tensorflow/compiler/mlir/lite/quantization:device_target", - "//tensorflow/compiler/mlir/lite/quantization:quantization_context", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "quantize", - srcs = [ - "quantize.cc", - ], - hdrs = [ - "quantize.h", - ], - deps = [ - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo", - "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core/platform:status", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Transforms", - ], -) - -gentbl( - name = "cpu_kernel_fusion_inc_gen", - tbl_outs = [ - ( - "-gen-rewriters", - "generated_cpu_kernel_fusion.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "cpu_kernel_fusion.td", - td_srcs = [ - "@llvm-project//mlir:StdOpsTdFiles", - "//tensorflow/compiler/mlir/xla:hlo_ops_td_files", - "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", - ], -) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc deleted file mode 100644 index b456af27fa5..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h" - -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h" - -namespace mlir { -namespace xla_hlo { - -namespace ph = std::placeholders; - -CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) { - RegisterKernel("generic.concat", {qi8_, qi8_, qi8_}, - quant::ScaleConstraintType::OutputInputSameScale); - - // TODO(fengliuai): All the combinations are required to list. We need to - // improve this. 
- RegisterKernel("generic.reshape", {qi8_, any_}, - quant::ScaleConstraintType::OutputInputSameScale); - RegisterKernel("generic.reshape", {any_, qi8_}, - quant::ScaleConstraintType::OutputInputSameScale); - - RegisterKernel("generic.mul", {qi8_, qi8_, qi8_}, - quant::ScaleConstraintType::OutputInputFreeScale); - RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_}, - std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale, - this, ph::_1, ph::_2, ph::_3, ph::_4)); - RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_}, - std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale, - this, ph::_1, ph::_2, ph::_3, ph::_4)); -} - -LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale( - quant::QuantizeContext* ctx, Operation* op, - quant::AdjacentOperations* new_items, bool* changed) { - auto bias_params = ctx->GetOperandParams(op, 2); - if (!EmptyParams(bias_params)) { - return success(); - } - std::vector op_types{ctx->GetOperandParams(op, 0), - ctx->GetOperandParams(op, 1)}; - auto bias_scale = GetUniformQuantizedTypeForBias(op_types); - if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) { - *changed = true; - new_items->push_back(op->getOperand(2).getDefiningOp()); - } - return success(); -} - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h deleted file mode 100644 index a2b05fb6a00..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ - -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" - -namespace mlir { -namespace xla_hlo { - -// Target specs for cpu kernels -class CpuDeviceTarget : public quant::DeviceTarget { - public: - explicit CpuDeviceTarget(MLIRContext* ctx); - - private: - LogicalResult HandleMultiplyAccumulateScale( - quant::QuantizeContext* ctx, Operation* op, - quant::AdjacentOperations* new_items, bool* changed); -}; - -} // namespace xla_hlo -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc deleted file mode 100644 index 47373e8bed9..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc +++ /dev/null @@ -1,346 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" -#include 
"tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/xla/client/lib/quantize.h" - -#define DEBUG_TYPE "quant-kernel-fusion" - -constexpr int kFakeQuantOperandsNum = 5; -constexpr int kFakeQuantPerChannelOperandsNum = 6; - -namespace mlir { -namespace xla_hlo { - -namespace { - -TypeAttr GetQuantSpec(Operation* op) { - auto fake_quant = llvm::dyn_cast_or_null(op); - if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum || - fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum || - fake_quant.call_target_name() != "fake_quant_with_min_max_vars") - return {}; - - DenseFPElementsAttr min, max; - DenseIntElementsAttr bit_width, narrow_range, quant_dim; - if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) || - !matchPattern(fake_quant.getOperand(2), m_Constant(&max)) || - !matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) || - !matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range))) - return {}; - - auto bit_width_val = (*bit_width.attr_value_begin()).cast(); - auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue(); - int quant_dim_val = -1; - if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum && - matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1), - m_Constant(&quant_dim))) { - quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue(); - } - - OpBuilder builder(op); - Type input_type = - fake_quant.getOperand(0).getType().cast().getElementType(); - return quant::GetQuantizedTypeAttr( - builder, input_type, min, max, quant_dim_val, bit_width_val, - builder.getBoolAttr(narrow_range_val), /*is_signed=*/true); -} - -// Collects input values from outside for 'ops'. 
-void CollectInputs(llvm::ArrayRef ops, - llvm::SmallVectorImpl* inputs, - llvm::SmallVectorImpl* input_specs) { - for (Operation* op : ops) { - for (Value operand : op->getOperands()) { - if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) { - continue; - } - if (Operation* def_op = operand.getDefiningOp()) { - if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) { - inputs->push_back(operand); - } - } else { // argument value - inputs->push_back(operand); - } - } - } - - for (Value input : *inputs) { - ShapedType input_type = input.getType().cast(); - if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) { - input_specs->push_back(spec); - } else { - input_specs->push_back(TypeAttr::get(input_type.getElementType())); - } - } -} - -// Collects values that are produced by 'ops' and have use outside of 'ops'. -// TODO(fengliuai): if it is a single user and QDQ, write that to the specs. -void CollectRets(llvm::ArrayRef ops, - llvm::SmallVectorImpl* rets, - llvm::SmallVectorImpl* ret_types, - llvm::SmallVectorImpl* ret_specs) { - for (Operation* op : ops) { - // The constant will not be shared outside the region. - if (llvm::isa(op)) continue; - - for (Value result : op->getResults()) { - for (Operation* user : result.getUsers()) { - // If there are any user outside of 'ops' - if (std::find(ops.begin(), ops.end(), user) == ops.end()) { - ShapedType ret_type = result.getType().cast(); - rets->push_back(result); - ret_types->push_back(ret_type); - if (TypeAttr spec = GetQuantSpec(user)) { - ret_specs->push_back(spec); - } else { - ret_specs->push_back(TypeAttr::get(ret_type.getElementType())); - } - break; - } - } - } - } -} - -enum FusedActivationFunc { NONE, RELU, RELU1, RELU6 }; - -#define FLOAT_EQ(value, x) fabs(value - x) <= 1e-6 - -// If the op is max(in, 0.0), we consider this is from Relu, so both this op -// and constant 0.0 will be fused. 
-// If the op is clamp(0.0, in, 1.0) or clamp(0.0, in, 6.0), we consider this is -// from Relu1 or Relu6, so all the constants and this op will be fused. -// Returns the activation function type. -FusedActivationFunc FuseReluX(Operation* op, - llvm::SmallVectorImpl* fused) { - if (auto max = llvm::dyn_cast(op)) { - Value min_val = max.rhs(); - llvm::SmallVector broadcast_ops; - if (auto broadcast = llvm::dyn_cast_or_null( - min_val.getDefiningOp())) { - min_val = broadcast.operand(); - broadcast_ops.push_back(broadcast); - } - DenseFPElementsAttr min; - if (!matchPattern(min_val, m_Constant(&min))) { - // In case the min value is lhs. - min_val = max.lhs(); - broadcast_ops.clear(); - if (auto broadcast = llvm::dyn_cast_or_null( - min_val.getDefiningOp())) { - min_val = broadcast.operand(); - broadcast_ops.push_back(broadcast); - } - if (!matchPattern(min_val, m_Constant(&min))) { - return NONE; - } - } - if (!min.isSplat() || - !(FLOAT_EQ(min.getSplatValue().cast().getValueAsDouble(), - 0.0))) { - return NONE; - } - - // Include the constant 0.0 as well, to avoid being quantized. - fused->push_back(min_val.getDefiningOp()); - fused->append(broadcast_ops.begin(), broadcast_ops.end()); - fused->push_back(max); - return RELU; - } - - if (auto clamp = llvm::dyn_cast(op)) { - DenseFPElementsAttr lower, upper; - if (!matchPattern(clamp.min(), m_Constant(&lower)) || - !matchPattern(clamp.max(), m_Constant(&upper)) || !lower.isSplat() || - !upper.isSplat() || - !(FLOAT_EQ(lower.getSplatValue().cast().getValueAsDouble(), - 0.0))) { - return NONE; - } - - double upper_value = - upper.getSplatValue().cast().getValueAsDouble(); - if (FLOAT_EQ(upper_value, 1.0) || FLOAT_EQ(upper_value, 6.0)) { - fused->push_back(clamp.min().getDefiningOp()); - fused->push_back(clamp.max().getDefiningOp()); - fused->push_back(op); - return (FLOAT_EQ(upper_value, 1.0) ? 
RELU1 : RELU6); - } - } - return NONE; -} - -llvm::SmallVector FuseOps(PatternRewriter* rewriter, - const std::initializer_list& results, - StringRef kernel) { - // Collect all the operations to be fused. - llvm::SmallVector fused; - llvm::SmallVector locs; - fused.reserve(results.size()); - locs.reserve(results.size()); - for (auto value : results) { - Operation* op = value.getDefiningOp(); - fused.push_back(op); - locs.push_back(op->getLoc()); - } - - Operation* root = fused.back(); - - FusedActivationFunc act_func = FusedActivationFunc::NONE; - // If there is Relu, Relu1 or Relu6, fuse it as well. - if (results.size() > 0 && std::rbegin(results)->hasOneUse()) { - act_func = FuseReluX(*std::rbegin(results)->user_begin(), &fused); - } - - // Collect inputs from outside to 'ops'. - llvm::SmallVector inputs; - llvm::SmallVector input_specs; - CollectInputs(fused, &inputs, &input_specs); - - // Collect outputs from 'ops' to outside. - llvm::SmallVector rets; - llvm::SmallVector ret_types; - llvm::SmallVector ret_specs; - CollectRets(fused, &rets, &ret_types, &ret_specs); - - // TODO(fengliuai): make activation function an attribute. - std::string kernel_name; - switch (act_func) { - case RELU: - kernel_name = llvm::Twine(kernel, "_relu").str(); - break; - case RELU1: - kernel_name = llvm::Twine(kernel, "_relu1").str(); - break; - case RELU6: - kernel_name = llvm::Twine(kernel, "_relu6").str(); - break; - default: - kernel_name = kernel.str(); - } - - // Create the region op with the return. - auto region = rewriter->create( - rewriter->getFusedLoc(locs), ret_types, inputs, - rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs), - kernel_name); - auto* body = new Block(); - region.body().push_back(body); - - OpBuilder builder = OpBuilder::atBlockEnd(body); - BlockAndValueMapping mapping; - - // Make block arguments and add it to the block value mapping. 
- for (Value input : inputs) { - mapping.map(input, body->addArgument(input.getType())); - } - - // Clone the operations 'ops' to the region. - for (Operation* op : fused) { - builder.clone(*op, mapping); - } - - llvm::SmallVector new_rets; - new_rets.reserve(rets.size()); - for (auto ret : llvm::enumerate(rets)) { - Value new_ret = mapping.lookupOrNull(ret.value()); - assert(new_ret && "couldn't find return value."); - new_rets.push_back(new_ret); - ret.value().replaceAllUsesWith(region.getResult(ret.index())); - } - builder.create(builder.getUnknownLoc(), new_rets); - - LLVM_DEBUG({ - assert(region.verify().Success && "failed to create quant region."); - llvm::dbgs() << "\ncreated region: "; - region.print(llvm::dbgs()); - llvm::dbgs() << "\n\n\n"; - }); - - // All uses of the fused ops are replaced, so the values in this vector - // will not be used. - SmallVector new_values(root->getNumResults(), region.getResult(0)); - return new_values; -} - -struct CpuKernelFusionPass : public FunctionPass { - explicit CpuKernelFusionPass() = default; - CpuKernelFusionPass(const CpuKernelFusionPass&) {} - - void runOnFunction() override; -}; - -#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc" - -void CpuKernelFusionPass::runOnFunction() { - Operation* op = getOperation(); - MLIRContext* ctx = op->getContext(); - OwningRewritePatternList patterns; - populateWithGenerated(ctx, &patterns); - applyPatternsGreedily(op->getRegions(), patterns); -} - -} // namespace - -// Creates an instance of the xla_hlo cpu kernel fusion pass. 
-std::unique_ptr> CreateCpuKernelFusionPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels"); - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td deleted file mode 100644 index 69240015242..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.td +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" -include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/IR/Ops.td" - -class Fused1Ops : NativeCodeCall< - "FuseOps(&$_builder, {$0}, \"" # kernel # "\")">; -class Fused2Ops : NativeCodeCall< - "FuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">; -class Fused3Ops : NativeCodeCall< - "FuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">; -class Fused4Ops : NativeCodeCall< - "FuseOps(&$_builder, {$0, $1, $2, $3}, \"" # kernel # "\")">; - -// We shouldn't revisit those ops which have been fused. This constraint is -// required because the greedy pattern rewriter will visit and match any new -// ops. So when the source pattern are matched and wrapped by the quant region -// op, these ops will be matched again. 
To prevent this, this constraint is -// added to bypass any ops inside a quant region. -def NeedsToBeFused : ConstraintgetParentOfType()">>; - -// dummy example -def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_), - (Fused2Ops<"generic.mul_add"> $mul, $add), - [(NeedsToBeFused $add)]>; - -// reduce_window: maxpool, avgpool -def : Pat<(HLO_ReduceWindowOp:$reduce $_, $_, $_, $_, $_, $_, $_), - (Fused1Ops<"generic.reduce_window"> $reduce), - [(NeedsToBeFused $reduce)]>; - -// reshape -def : Pat<(HLO_ReshapeOp:$reshape $_), (Fused1Ops<"generic.reshape"> $reshape), - [(NeedsToBeFused $reshape)]>; - -// broadcast -def : Pat<(HLO_BroadcastInDimOp:$broadcast $_, $_), - (Fused1Ops<"generic.broadcast"> $broadcast), - [(NeedsToBeFused $broadcast)]>; - -// dot -> add -def : Pat<(HLO_AddOp:$add (HLO_DotOp:$dot $_, $_, $_), $_, $_), - (Fused2Ops<"generic.biased_dot"> $dot, $add), - [(NeedsToBeFused $add)]>; - -// conv -> add -def : Pat<(HLO_AddOp:$add - (HLO_ConvOp:$conv $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), $_, $_), - (Fused2Ops<"generic.biased_conv"> $conv, $add), - [(NeedsToBeFused $add)]>; diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc deleted file mode 100644 index 25a5f38bf0a..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass quantize the constant and rewrite the quantization -// ops by xla_hlo primitive ops. -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/xla/client/lib/quantize.h" - -//===----------------------------------------------------------------------===// -// The pass to materialize the quantization results by xla primitive ops. -// -namespace mlir { -namespace xla_hlo { - -namespace { - -// This pattern matches the "constant->qcast->dcast" pattern and replaces it by -// "quantized constant->xla_hlo.dequantize". If it only matches the -// "non-constant->qcast->dcast" pattern, it will remove both the "qcast->dcast". -// We chain the pattern as a whole to bypass the type checks of the normal -// xla_hlo ops. -// TODO(fengliuai): make this pass work for bf16 input. 
-class RewriteDequantize : public OpRewritePattern { - public: - explicit RewriteDequantize(int64_t size, MLIRContext *context) - : OpRewritePattern(context), size_(size) {} - - LogicalResult matchAndRewrite(quant::DequantizeCastOp op, - PatternRewriter &rewriter) const override { - // quant.dcast - // xla_hlo dequantize only takes min/max, so let's recover them from - // the quantization parameters. - Value dcast = op.arg(); - auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType()); - if (!type || !type.isa()) { - return failure(); - } - auto qtype = type.cast(); - double scale = qtype.getScale(); - int64_t zero_point = qtype.getZeroPoint(); - float min = scale * (qtype.getStorageTypeMin() - zero_point); - float max = scale * (qtype.getStorageTypeMax() - zero_point); - - // quant.qcast - auto qcast = - llvm::dyn_cast_or_null(dcast.getDefiningOp()); - if (!qcast) return failure(); - - // constant - DenseFPElementsAttr attr; - // If it isn't a floating-point constant or the size is too small, let's - // remove the quantization. Also the last dimension size should be a - // multiplier of 4, so the shape isn't broken during packing and unpacking. - if (!matchPattern(qcast.arg(), m_Constant(&attr)) || - attr.getNumElements() <= size_ || - attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) { - op.getResult().replaceAllUsesWith(qcast.arg()); - return success(); - } - // TODO(fengliuai): implement transpose if it has high dimension. - - // Create the quantized result - auto quantized_result = - quant::Quantize(attr, qtype).dyn_cast_or_null(); - if (!quantized_result) { - return failure(); - } - - // Pack the uint8 bits to uint32. The shape is changed from from - // [n0, n1, ..., nk] to [n0, n1, ..., nk / 4]. - std::vector raw_data; - for (auto d : quantized_result.getValues()) { - raw_data.push_back(d); - } - // The packing might increase the data size by paddings. 
- auto packed_data = xla::PackToUint32(raw_data); - auto packed_shape = attr.getType().getShape().vec(); - int lower_dims = std::accumulate( - packed_shape.begin(), - std::next(packed_shape.begin(), packed_shape.size() - 1), 1, - std::multiplies()); - packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims; - auto packed_type = - RankedTensorType::get(packed_shape, rewriter.getIntegerType(32)); - - auto packed_quantized_result = - DenseElementsAttr::get(packed_type, packed_data); - auto quantized_constant = - rewriter.create(qcast.getLoc(), packed_quantized_result); - - // Create the xla dequantize op with bf16 output - auto dequantized_type = RankedTensorType::get(attr.getType().getShape(), - rewriter.getBF16Type()); - auto dequantize = rewriter.create( - qcast.getLoc(), dequantized_type, quantized_constant, - rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max), - rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - // Convert bf16 output back to f32 - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), - dequantize); - return success(); - } - - private: - int64_t size_; -}; - -// Materialize the quantization results by hlo primitive ops. -struct MaterializeToXlaPass : public FunctionPass { - explicit MaterializeToXlaPass() = default; - MaterializeToXlaPass(const MaterializeToXlaPass &) {} - - void runOnFunction() override; -}; - -void MaterializeToXlaPass::runOnFunction() { - FuncOp func = getFunction(); - MLIRContext *ctx = &getContext(); - - OwningRewritePatternList patterns; - // TODO(fengliuai): make the size 6 configurable. - patterns.insert(6, ctx); - - applyPatternsGreedily(func, patterns); -} - -} // namespace - -// Creates an instance of the xla_hlo dialect quantization propagation pass. 
-std::unique_ptr> CreateMaterializeToXlaPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "xla-hlo-materialize-quant", - "Materialize the quantization results by xla primitve ops"); - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc b/tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc deleted file mode 100644 index fc469208467..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc +++ /dev/null @@ -1,7 +0,0 @@ -// TODO(fengliuai): automatically generate this file -// TODO(fengliuai): add all the xla_hlo ops - -static std::unique_ptr GetOpQuantSpec(mlir::Operation *op) { - auto spec = absl::make_unique(); - return spec; -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h b/tensorflow/compiler/mlir/lite/quantization/xla/passes.h deleted file mode 100644 index c4f9d63cf68..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/passes.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_ - -#include - -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace xla_hlo { - -// Propagate the quantization information to all the tensors according to the -// op quant spec. -std::unique_ptr> CreatePropagateQuantPass(); - -// Rewrite the graph and quantize the constant. -std::unique_ptr> CreateMaterializeToXlaPass(); - -} // namespace xla_hlo -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc deleted file mode 100644 index 22dd4357416..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass applies quantization propagation on xla_hlo dialect. 
-#include -#include - -#include "absl/memory/memory.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/CommandLine.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h" - -// NOLINTNEXTLINE -static llvm::cl::opt disable_per_channel( - "xla-disable-per-channel", llvm::cl::value_desc("bool"), - llvm::cl::desc("Whether disable per-channel quantized weights."), - llvm::cl::init(false)); - -//===----------------------------------------------------------------------===// -// The quantization propagation Pass. -// -namespace mlir { -namespace xla_hlo { - -namespace { - -// Applies the quantization propagation on the input function. During the -// propagation, two facts are respected: -// - The quantization type (params) of the ops in the function -// - The quantization spec for the ops -// The propagation results should assign quantization types to all the tensors -// and the two restrictions are respected. -struct PropagateQuantPass : public FunctionPass { - explicit PropagateQuantPass() = default; - PropagateQuantPass(const PropagateQuantPass &) {} - - void runOnFunction() override; -}; - -#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc" - -void PropagateQuantPass::runOnFunction() { - FuncOp func = getFunction(); - // TODO(fengliuai): deprecate this old code generation path. - // XLA only support uint8/uint16 quantization for now. 
- ApplyQuantizationParamsPropagation(func, /*is_signed*/ false, - disable_per_channel, GetOpQuantSpec); - - CpuDeviceTarget spec(&getContext()); - quant::QuantizeContext ctx(func, spec); - - std::vector work_list = ctx.GetAllOps(); - bool changed = false; - while (!work_list.empty()) { - quant::QuantizeRegionOp op = work_list.back(); - work_list.pop_back(); - - llvm::SmallVector new_items; - if (failed(ctx.Handle(op, &new_items, &changed))) { - // The IR is still valid, thus we shouldn't fail. - signalPassFailure(); - } - for (auto item : new_items) { - if (auto reg = llvm::dyn_cast_or_null(item)) - work_list.push_back(reg); - } - } - - if (!changed) return; - - if (failed(ctx.Finalize())) { - signalPassFailure(); - } -} - -} // namespace - -// Creates an instance of the xla_hlo dialect quantization propagation pass. -std::unique_ptr> CreatePropagateQuantPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "xla-hlo-propagate-quant", "Propagate quantization information"); - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc deleted file mode 100644 index 9df41bb660a..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/tf2xla/tf2xla.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" - -namespace mlir { -namespace xla_hlo { - -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; -} - -// Quantizes the model in the computation. 
-tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, - xla::XlaComputation* computation) { - TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot, - computation->Snapshot()); - - RegisterDialects(); - MLIRContext context; - OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context)); - auto status = xla::ConvertHloToMlirHlo( - module.get(), snapshot->mutable_hlo()->mutable_hlo_module()); - if (!status.ok()) { - LOG(ERROR) << "Hlo module import failed: " << status; - return status; - } - - PassManager pm(&context); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createInlinerPass()); - pm.addPass(createSymbolDCEPass()); - pm.addNestedPass(createCSEPass()); - - mlir::StatusScopedDiagnosticHandler diag_handler(&context); - LogicalResult result = pm.run(module.get()); - (void)result; - - module->dump(); - - return tensorflow::Status::OK(); -} - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD deleted file mode 100644 index 4b6b4212567..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") - -package(licenses = ["notice"]) - -glob_lit_tests( - data = [ - ":graph_config_files", - ":test_utilities", - ], - driver = "@llvm-project//mlir:run_lit.sh", - tags_override = { - "fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss. - }, - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//tensorflow/compiler/aot:tfcompile", - "//tensorflow/compiler/mlir:tf-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ], -) - -# Bundle together all the graph files that are used by the tests. 
-filegroup( - name = "graph_config_files", - srcs = glob( - ["**/*.pbtxt"], - ), -) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir deleted file mode 100644 index 9920e1f214f..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/cpu_kernel_fusion.mlir +++ /dev/null @@ -1,199 +0,0 @@ -// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s - -// CHECK-LABEL: @mul_add_source -func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %1 : tensor<4xf32> - -// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( { -// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors -// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> -// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> -// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> () -// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[region]] : tensor<4xf32> -} - -// CHECK-LABEL: @mul_add_annotated -func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) { - %cst = constant dense<0.0> : tensor - %cst_0 = constant dense<255.0> : tensor - %cst_1 = constant dense<8> : tensor - %cst_2 = constant dense : tensor - %qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) 
{backend_config = "", call_target_name = "fake_quant_with_min_max_vars", - has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> - %qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", - has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> - %0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> - %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> - %r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", - has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> - return %r : tensor<2x4xf32> - -// CHECK: %[[region:.*]] = "quant.region" -// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors -// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32> -// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32> -// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> () -// CHECK: }) {input_specs = [!quant.uniform, !quant.uniform, f32], -// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : -// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]] -// CHECK: return %[[r]] : tensor<2x4xf32> -} - -// CHECK-LABEL: @reduce_window -func @reduce_window(%arg0: tensor<1x28x28x32xf32>, %arg1: tensor) -> (tensor<1x14x14x32xf32>) { - %0 = 
"xla_hlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.maximum %arg2, %arg3 : tensor - "xla_hlo.return"(%1) : (tensor) -> () - }) { - base_dilations = dense<1> : tensor<4xi64>, - padding = dense<0> : tensor<4x2xi64>, - window_dilations = dense<1> : tensor<4xi64>, - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> - } : (tensor<1x28x28x32xf32>, tensor) -> tensor<1x14x14x32xf32> - return %0 : tensor<1x14x14x32xf32> - -// CHECK: "quant.region"(%arg0, %arg1) ( { -// CHECK: ^bb0(%arg2: tensor<1x28x28x32xf32>, %arg3: tensor): // no predecessors -// CHECK: %[[rw:.*]] = "xla_hlo.reduce_window"(%arg2, %arg3) ( { -// CHECK: ^bb0(%arg4: tensor, %arg5: tensor): // no predecessors -// CHECK: %[[max:.*]] = xla_hlo.maximum %arg4, %arg5 : tensor -// CHECK: "xla_hlo.return"(%[[max]]) : (tensor) -> () -// CHECK: }) -// CHECK: "quant.return"(%[[rw]]) -// CHECK: }) {input_specs = [f32, f32], logical_kernel = "generic.reduce_window", output_specs = [f32]} -} - -// CHECK-LABEL: @reshape -func @reshape(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32> - return %0 : tensor<1x3136xf32> - -// CHECK: "quant.region"(%arg0) -// CHECK: logical_kernel = "generic.reshape" -} - -// CHECK-LABEL: @broadcast -func @broadcast(%arg0: tensor<64xf32>) -> (tensor<1x14x14x64xf32>) { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<1x14x14x64xf32> - return %0 : tensor<1x14x14x64xf32> - -// CHECK: "quant.region"(%arg0) -// CHECK: logical_kernel = "generic.broadcast" -} - -// CHECK-LABEL: @biased_dot -func @biased_dot(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> 
tensor<1x10xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32> - return %1 : tensor<1x10xf32> - -// CHECK: "quant.region"(%arg0, %arg1, %arg2) -// CHECK: xla_hlo.dot -// CHECK: xla_hlo.add -// CHECK: logical_kernel = "generic.biased_dot" -} - -// CHECK-LABEL: @biased_conv -func @biased_conv(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) { - %0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32> - return %1 : tensor<1x14x14x64xf32> - -// CHECK: "quant.region"(%arg0, %arg1, %arg2) -// CHECK: xla_hlo.conv -// CHECK: xla_hlo.add -// CHECK: logical_kernel = "generic.biased_conv" -} - -// CHECK-LABEL: @biased_dot_relu -func @biased_dot_relu(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) { - %cst = constant dense<0.0> : tensor<1x10xf32> - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> tensor<1x10xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32> - %2 = xla_hlo.maximum %1, %cst : tensor<1x10xf32> - return %2 : tensor<1x10xf32> - -// CHECK: "quant.region"(%arg0, %arg1, 
%arg2) -// CHECK: constant -// CHECK: xla_hlo.dot -// CHECK: xla_hlo.add -// CHECK: xla_hlo.maximum -// CHECK: logical_kernel = "generic.biased_dot_relu" -} - -// CHECK-LABEL: @biased_conv_relu -func @biased_conv_relu(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) { - %cst = constant dense<0.0> : tensor<1x14x14x64xf32> - %0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32> - %2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32> - return %2 : tensor<1x14x14x64xf32> - -// CHECK: "quant.region"(%arg0, %arg1, %arg2) -// CHECK: constant -// CHECK: xla_hlo.conv -// CHECK: xla_hlo.add -// CHECK: xla_hlo.maximum -// CHECK: logical_kernel = "generic.biased_conv_relu" -} - -// CHECK-LABEL: @biased_conv_relu_shared -func @biased_conv_relu_shared(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) { - %cst = constant dense<0.0> : tensor<1x14x14x64xf32> - %0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, - 
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32> - %2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32> - return %cst, %2 : tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32> - -// CHECK: "quant.region"(%arg0, %arg1, %arg2) -// CHECK: constant -// CHECK: xla_hlo.conv -// CHECK: xla_hlo.add -// CHECK: %[[max:.*]] = xla_hlo.maximum -// CHECK: "quant.return"(%[[max]]) -// CHECK: logical_kernel = "generic.biased_conv_relu" -} - -// CHECK-LABEL: @biased_conv_relu6 -func @biased_conv_relu6(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) { - %min = constant dense<0.0> : tensor<1x14x14x64xf32> - %max = constant dense<6.0> : tensor<1x14x14x64xf32> - %0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, 
precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32> - %1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32> - %2 = "xla_hlo.clamp"(%min, %1, %max) : (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) -> tensor<1x14x14x64xf32> - return %2 : tensor<1x14x14x64xf32> - -// CHECK: "quant.region"(%arg0, %arg1, %arg2) -// CHECK: constant -// CHECK: constant -// CHECK: xla_hlo.conv -// CHECK: xla_hlo.add -// CHECK: xla_hlo.clamp -// CHECK: logical_kernel = "generic.biased_conv_relu6" -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir deleted file mode 100644 index d3e8b48daa8..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir +++ /dev/null @@ -1,15 +0,0 @@ -# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure - -# TODO(fengliuai): update this file with the progress of the implementation -// CHECK: func @main -// CHECK: %cst = constant dense<0.000000e+00> : tensor -// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor -// CHECK: %cst_1 = constant dense<8> : tensor -// CHECK: %cst_2 = constant dense : tensor -// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> -// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> -// CHECK: %2 = 
xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32> -// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> -// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple> -// CHECK: return %4 : tuple> -// CHECK: } diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt deleted file mode 100644 index 1e97c1fa326..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt +++ /dev/null @@ -1,26 +0,0 @@ -feed { - id { node_name: "input0" } - shape { - dim { size: 2 } - dim { size: 4 } - } -} -feed { - id { node_name: "input1" } - shape { - dim { size: 2 } - dim { size: 4 } - } -} - -fetch { - id { node_name: "Add/FakeQuantWithMinMaxVars" } - shape { - dim { size: 2 } - dim { size: 4 } - } -} - -conversion_options { - custom_fake_quant_op_calls: true -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt deleted file mode 100644 index 6995c861fd0..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt +++ /dev/null @@ -1,218 +0,0 @@ -node: { - name: "Add/FakeQuantWithMinMaxVars" - op: "FakeQuantWithMinMaxVars" - input: "Add" - input: "Add/FakeQuantWithMinMaxVars/min" - input: "Add/FakeQuantWithMinMaxVars/max" - attr: { - key: "num_bits" - value: { - i: 8 - } - } - attr: { - key: "narrow_range" - value: { - b: false - } - } -} -node: { - name: "Add/FakeQuantWithMinMaxVars/min" - op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 0.0 - } - } - } - attr: { - key: "dtype" 
- value: { - type: DT_FLOAT - } - } -} -node: { - name: "Add/FakeQuantWithMinMaxVars/max" - op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 127.0 - } - } - } - attr: { - key: "dtype" - value: { - type: DT_FLOAT - } - } -} -node { - name: "Add" - op: "Add" - input: "input0/FakeQuantWithMinMaxVars" - input: "input1/FakeQuantWithMinMaxVars" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node: { - name: "input0/FakeQuantWithMinMaxVars" - op: "FakeQuantWithMinMaxVars" - input: "input0" - input: "input0/FakeQuantWithMinMaxVars/min" - input: "input0/FakeQuantWithMinMaxVars/max" - attr: { - key: "num_bits" - value: { - i: 8 - } - } - attr: { - key: "narrow_range" - value: { - b: false - } - } -} -node: { - name: "input0/FakeQuantWithMinMaxVars/min" - op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 0.0 - } - } - } - attr: { - key: "dtype" - value: { - type: DT_FLOAT - } - } -} -node: { - name: "input0/FakeQuantWithMinMaxVars/max" - op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 127.0 - } - } - } - attr: { - key: "dtype" - value: { - type: DT_FLOAT - } - } -} -node { - name: "input0" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node: { - name: "input1/FakeQuantWithMinMaxVars" - op: "FakeQuantWithMinMaxVars" - input: "input1" - input: "input1/FakeQuantWithMinMaxVars/min" - input: "input1/FakeQuantWithMinMaxVars/max" - attr: { - key: "num_bits" - value: { - i: 8 - } - } - attr: { - key: "narrow_range" - value: { - b: false - } - } -} -node: { - name: "input1/FakeQuantWithMinMaxVars/min" - op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 0.0 - } - } - } - attr: { - key: "dtype" - value: { - type: DT_FLOAT - } - } -} -node: { - name: "input1/FakeQuantWithMinMaxVars/max" - 
op: "Const" - attr: { - key: "value" - value: { - tensor: { - dtype: DT_FLOAT - tensor_shape: { - } - float_val: 127.0 - } - } - } - attr: { - key: "dtype" - value: { - type: DT_FLOAT - } - } -} -node { - name: "input1" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -versions { - producer: 27 -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir deleted file mode 100644 index c731d72f752..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/materialize.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s - -// CHECK-LABEL: func @quantize_rewrite -func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { -// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32> -// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32, -// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16> -// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32> -// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32> - - %w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32> - %q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> - %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform>) -> tensor<2x4xf32> - %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32> - return %mul: tensor<2x4xf32> -} - -// CHECK-LABEL: func @quantize_small -func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32> -// CHECK-NEXT: 
return %[[mul]] : tensor<1x4xf32> - - %w = constant dense<1.0> : tensor<1x4xf32> - %q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> - %dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> - %mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32> - return %mul: tensor<1x4xf32> -} - -// CHECK-LABEL: func @quantize_non_cst -func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32> -// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32> - - %q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> - %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform>) -> tensor<2x4xf32> - %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32> - return %mul: tensor<2x4xf32> -} - -// CHECK-LABEL: func @quantize_non_4x -func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { -// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32> -// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32> -// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32> - - %w = constant dense<1.0> : tensor<2x5xf32> - %q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform> - %dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform>) -> tensor<2x5xf32> - %mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32> - return %mul: tensor<2x5xf32> -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir deleted file mode 100644 index a504be01827..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure - -// ----- - -// CHECK-LABEL: @mul_add_source_no_params -func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - 
%region = "quant.region"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors - %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> - %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> - "quant.return"(%add) : (tensor<4xf32>) -> () - }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : - (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %region : tensor<4xf32> - -// CHECK: input_specs = [f32, f32, f32] -// CHECK-SAME: output_specs = [f32] -} - -// ----- - -// CHECK-LABEL: @mul_add_annotated_no_narrow_range -func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %region = "quant.region"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors - %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> - %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> - "quant.return"(%add) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform, f32], - logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : - (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %region : tensor<4xf32> - -// CHECK: input_specs = [!quant.uniform, !quant.uniform, f32] -// CHECK-SAME: output_specs = [!quant.uniform] -} - -// ----- - -// CHECK-LABEL: @mul_add_annotated -func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %region = "quant.region"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors - %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : 
tensor<0xi64>} : tensor<4xf32> - %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32> - "quant.return"(%add) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform:f32, 1.0:-128>, f32], - logical_kernel = "generic.mul_add", output_specs = [!quant.uniform]} : - (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %region : tensor<4xf32> - -// CHECK: input_specs = [!quant.uniform, !quant.uniform:f32, 1.000000e+00:-128>, !quant.uniform] -// CHECK-SAME: output_specs = [!quant.uniform] -} - -// ----- - -// CHECK-LABEL: @same_scale_1_1 -func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) { - %region = "quant.region"(%arg0) ( { - ^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors - %r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) - "quant.return"(%r) : (tensor<1x3136xf32>) -> () - }) {input_specs = [!quant.uniform], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32> - return %region : tensor<1x3136xf32> - -// CHECK: input_specs = [!quant.uniform] -// CHECK-SAME: output_specs = [!quant.uniform] -} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir deleted file mode 100644 index 8f0936c41af..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/weight-only.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s - -// CHECK-LABEL: func @mul -func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { -// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32> -// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> -// 
CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32> -// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32> - %w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32> - %mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32> - return %mul: tensor<2x2xf32> -} - -// CHECK-LABEL: func @add -func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { -// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32> -// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> -// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform>) -> tensor<2xf32> -// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32> -// CHECK-NEXT: return %[[add]] : tensor<2x2xf32> - %b = constant dense<1.0> : tensor<2xf32> - %add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32> - return %add: tensor<2x2xf32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt index f69b14d8073..345468e609e 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt @@ -39,7 +39,7 @@ versions { # CHECK-LABEL: func @main # CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32> # CHECK-SAME: control_outputs = "" -# CHECK-SAME inputs = "input0,input1" +# CHECK-SAME: inputs = "input0,input1" # CHECK-SAME: outputs = "output" # CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> # CHECK-NEXT: return %[[OP]] : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD 
b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index b52b766a10d..da3fe02562b 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( test_file_exts = [ "mlir", "cc", + "json", ], ) @@ -24,6 +25,8 @@ filegroup( ":importer_test_min_max", "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:flatbuffer_translate", + "//tensorflow/compiler/mlir/lite:json_to_flatbuffer", + "//tensorflow/lite/schema:schema.fbs", "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json new file mode 100644 index 00000000000..d6d3b142931 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json @@ -0,0 +1,83 @@ +// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s + +// CHECK: %cst = constant unit +// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> +// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32> + +{ + version: 3, + operator_codes: [ + { + builtin_code: "CONV_2D", + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 256, + 32, + 32, + 3 + ], + name: "arg0", + quantization: { + } + }, + { + shape: [ + 16, + 3, + 3, + 3 + ], + name: "arg1", + quantization: { + } + }, + { + shape: [ + 0 + ], + name: "cst" + }, + { + shape: [ + 256, + 32, + 32, + 16 + ], + name: "output", + quantization: { + } + }, + ], + inputs: [ + 0, + 1 + ], + outputs: [ + 3 + ], + operators: [ + { + inputs: [ + 0, + 1, + -1 + ], + 
outputs: [ + 3 + ], + builtin_options_type: "Conv2DOptions", + builtin_options: { + } + } + ], + name: "main" + } + ], + description: "MLIR Converted." +} diff --git a/tensorflow/compiler/mlir/lite/tests/inlining.mlir b/tensorflow/compiler/mlir/lite/tests/inlining.mlir index ff7921b8aca..8f19e9983b5 100644 --- a/tensorflow/compiler/mlir/lite/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/lite/tests/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail +// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail // Inline a function that contains only tfl ops. func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir index 50ea5c1da41..c66e6d0145c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir @@ -1,5 +1,5 @@ // RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure -// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE +// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --dump-input-on-failure --check-prefix=INLINE // RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON func @while_main(%arg0: tensor) -> (tensor, tensor<256x256xf32>, tensor) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 7e9b1bdb711..8c39b8e9b06 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ 
b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { // CHECK: return } +// CHECK-LABEL: testAddHighDimsHaveSameShape +func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> { + // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> + return %0 : tensor<1x2x3x4x5x6x7x8xi32> +} + +// CHECK-LABEL: testAddTooHighBroadcastableDims +func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { + // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}} + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> + return %0 : tensor<1x2x3x4x5x6xi32> +} + func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32> return %2: tensor<1xf32> @@ -1448,7 +1462,7 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3 // CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor // CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xf32> // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK return [[MUL]] : tensor<3x3xf32> +// CHECK: return [[MUL]] : tensor<3x3xf32> } func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { @@ -1459,5 +1473,5 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3 // CHECK: [[CST:%.*]] = constant dense<1> : tensor // CHECK: [[FILL:%.*]] 
= "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xi32> // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK return [[MUL]] : tensor<3x3xi32> +// CHECK: return [[MUL]] : tensor<3x3xi32> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 69a42d884cb..57f15719cfd 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -29,7 +29,7 @@ limitations under the License. namespace mlir { /// Create a pass to convert from the TFExecutor to the TF control dialect. -std::unique_ptr> +std::unique_ptr> CreateTFExecutorToControlDialectConversion(); } // namespace mlir @@ -134,6 +134,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass( mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + if (pass_config.shape_inference) { + // Add a shape inference pass to optimize away the unnecessary casts. + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } // Legalize while early to allow further constant folding. // TODO(jpienaar): This may not actually matter as we do canonicalization // after the legalize below, for now it needs to be below the above passes @@ -160,11 +164,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // constant ops. pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); - if (pass_config.shape_inference) { - // Add a shape inference pass to optimize away the unnecessary casts. - pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); - } - // The below passes only make sense if Builtin TFLite ops are enabled // for emission. 
if (pass_config.emit_builtin_tflite_ops) { @@ -173,7 +172,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass( mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); pass_manager->addNestedPass(mlir::createCanonicalizerPass()); - pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass()); + pass_manager->addPass( + mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); pass_manager->addPass(mlir::TFL::CreateOptimizePass()); // This pass operates on TensorFlow ops but is triggered after legalization // so that it can target constants introduced once TensorFlow Identity ops @@ -255,7 +255,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm, // TFLite dialect passes. pm.addPass(mlir::TFL::CreatePrepareTFPass(true)); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::TFL::CreateLegalizeTFPass()); + pm.addPass( + mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true)); pm.addPass(mlir::TFL::CreateOptimizePass()); pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); @@ -268,7 +269,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm, pm.addPass(mlir::TFL::CreateWhileOutlinePass()); - pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); + pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); } // Registers a pass pipeline for the standard TFL passes. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 038adebabef..ade1e325617 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/strings/str_split.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" @@ -214,7 +216,7 @@ int main(int argc, char **argv) { if (pass_config.legalize_tf_while) { pm.addPass(mlir::TFL::CreateWhileOutlinePass()); } - pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); + pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); std::string result; auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index aacc1ad2fd6..0c82a71f952 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -40,7 +40,7 @@ limitations under the License. namespace mlir { /// Create a pass to convert from the TFExecutor to the TF control dialect. -std::unique_ptr> +std::unique_ptr> CreateTFExecutorToControlDialectConversion(); } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 0bbacd48ade..0319e8555fa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -44,7 +44,8 @@ namespace TFL { #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" namespace { -class DefaultQuantParamsPass : public FunctionPass { +class DefaultQuantParamsPass + : public PassWrapper { public: explicit DefaultQuantParamsPass(double default_min, double default_max) : default_min_(default_min), default_max_(default_max) {} @@ -220,7 +221,7 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( } // Creates an instance of the default quant parameters pass. 
-std::unique_ptr> CreateDefaultQuantParamsPass( +std::unique_ptr> CreateDefaultQuantParamsPass( double default_min, double default_max) { return absl::make_unique(default_min, default_max); } diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 2341c0306f1..4c3a95dc2a4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -29,7 +29,7 @@ namespace TFL { namespace { -struct DenseToSparse : public FunctionPass { +struct DenseToSparse : public PassWrapper { void runOnFunction() override; }; @@ -63,7 +63,7 @@ void DenseToSparse::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect DenseToSparse pass. -std::unique_ptr> CreateDenseToSparsePass() { +std::unique_ptr> CreateDenseToSparsePass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.cc b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.cc index 01430d99a65..23af1ffee64 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.cc @@ -18,7 +18,8 @@ namespace mlir { namespace TFL { namespace { -struct IdentifyDilatedConvPass : public FunctionPass { +struct IdentifyDilatedConvPass + : public PassWrapper { void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 51b14d2013b..1d50c4dc29b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -679,7 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, return success(); } -struct ExtractOphintPass : public OperationPass { +struct ExtractOphintPass + : public PassWrapper> { void runOnOperation() override; void Verify(); @@ -752,7 
+753,7 @@ void ExtractOphintPass::Verify() { /// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass /// pass. -std::unique_ptr> CreateExtractOphintPass() { +std::unique_ptr> CreateExtractOphintPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index 299a8774db6..652d10a53a8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -69,7 +69,7 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm"; // | // OutputOp1 struct LegalizeOphintFuncOpPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -284,7 +284,7 @@ void LegalizeOphintFuncOpPass::runOnOperation() { /// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass /// pass. -std::unique_ptr> CreateLegalizeOphintFuncOpPass() { +std::unique_ptr> CreateLegalizeOphintFuncOpPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 98501aaa803..d9b33f3fa72 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -70,8 +70,21 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn"; constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; // Legalize operations in functions. -struct LegalizeTF : public FunctionPass { +class LegalizeTF : public PassWrapper { + public: + LegalizeTF() = default; + LegalizeTF(const LegalizeTF&) {} + explicit LegalizeTF(bool run_tfl_runtime_verification) { + run_tfl_runtime_verification_ = run_tfl_runtime_verification; + } + + /// Performs the lowering to TFLite dialect. 
void runOnFunction() override; + + private: + Option run_tfl_runtime_verification_{ + *this, "run-tfl-runtime-verification", + llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)}; }; // Returns true if all tensor value in `values` has static shape and same shape. @@ -314,7 +327,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, // can't do any padding. Instead we just return it. return attribute; } - for (auto idx : dense_elem_attr.getIntValues()) { + for (const auto& idx : dense_elem_attr.getIntValues()) { padded_val.push_back(idx.getSExtValue()); } auto attr_dim_count = ranked_attr_type.getShape()[0]; @@ -440,7 +453,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(), m_Constant(&padding_value))) return false; - for (auto value : padding_value.getValues()) { + for (const auto& value : padding_value.getValues()) { if (value != 0) return false; } @@ -741,13 +754,19 @@ void LegalizeTF::runOnFunction() { // graph. target.addLegalOp(); target.addLegalOp(); - target.addDynamicallyLegalDialect( - Optional([](Operation* op) { - auto tfl_op = dyn_cast_or_null(op); - if (!tfl_op) return false; - return succeeded(tfl_op.VerifyTflRuntimeTypes( - tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false)); - })); + if (run_tfl_runtime_verification_) { + target.addDynamicallyLegalDialect( + Optional( + [](Operation* op) { + auto tfl_op = dyn_cast_or_null(op); + if (!tfl_op) return false; + return succeeded(tfl_op.VerifyTflRuntimeConstraints( + tfl_op.getOperation(), + /*failure_on_operand_type_mismatch=*/false)); + })); + } else { + target.addLegalDialect(); + } // Keep trying to convert. // TODO(karimnosseir): This is similar to what apply greedy patterns does. // Look if there is a function that tries until it converge. 
@@ -763,8 +782,9 @@ void LegalizeTF::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. -std::unique_ptr> CreateLegalizeTFPass() { - return std::make_unique(); +std::unique_ptr> CreateLegalizeTFPass( + bool run_tfl_runtime_verification) { + return std::make_unique(run_tfl_runtime_verification); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index e85a85f26cb..31e3f6dd005 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -31,7 +31,8 @@ namespace { // Legalize TF While to TFL While with calls to the original functions from the // cond and body regions. -struct LegalizeWhile : public OperationPass { +struct LegalizeWhile + : public PassWrapper> { void RunOnFunction(FuncOp func); void runOnOperation() override { @@ -76,7 +77,7 @@ void LegalizeWhile::RunOnFunction(FuncOp func) { } // Creates an instance of the TensorFlow While to TFLite While pass. -std::unique_ptr> CreateLegalizeTFWhilePass() { +std::unique_ptr> CreateLegalizeTFWhilePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index 3d42f81a758..307a45639c5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -42,7 +42,8 @@ namespace { // AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also // defines the op quantization traits, which are used to propagate the // quantization parameters by the following passes. 
-struct LoadQuantizationRecipe : public FunctionPass { +struct LoadQuantizationRecipe + : public PassWrapper { void runOnFunction() override; private: @@ -215,7 +216,7 @@ void LoadQuantizationRecipe::runOnFunction() { // Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe // pass. -std::unique_ptr> CreateLoadQuantizationRecipePass() { +std::unique_ptr> CreateLoadQuantizationRecipePass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 17d0f6743a1..889f9dde00b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -82,7 +82,7 @@ class TensorListPatternRewriter : public PatternRewriter { /// Lower TensorList ops in functions for subsequent legalization. struct LowerStaticTensorListPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; // Apply type and op changes within a function. @@ -720,7 +720,7 @@ struct ConvertTensorListStack RankedTensorType::get({-1}, rewriter.getIntegerType(32)); auto new_shape = rewriter.create(loc, shape_type, input); SmallVector output_shape = {op.num_elements().getSExtValue()}; - for (auto dim : dense_elem_attr.getIntValues()) + for (const auto &dim : dense_elem_attr.getIntValues()) output_shape.push_back(dim.getSExtValue()); RankedTensorType result_type = RankedTensorType::get(output_shape, getElementTypeOrSelf(input)); @@ -906,7 +906,8 @@ void LowerStaticTensorListPass::runOnOperation() { /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList /// pass. 
-std::unique_ptr> TFL::CreateLowerStaticTensorListPass() { +std::unique_ptr> +TFL::CreateLowerStaticTensorListPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index e324f614ca4..ad1577674fd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -74,7 +74,7 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { using ::llvm::cast; // Optimize TFLite operations in functions. -struct Optimize : public FunctionPass { +struct Optimize : public PassWrapper { void runOnFunction() override; }; @@ -650,7 +650,7 @@ struct ConvertTrivialTransposeOpToReshapeOp auto input_shape = input_type.getShape(); SmallVector perm_values; - for (auto dim : perm_values_attr.getIntValues()) + for (const auto &dim : perm_values_attr.getIntValues()) perm_values.push_back(dim.getSExtValue()); // This should never happen unless the input graph is malformed. @@ -725,7 +725,7 @@ void Optimize::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect Optimize pass. -std::unique_ptr> CreateOptimizePass() { +std::unique_ptr> CreateOptimizePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 302194e1293..062a78e14d4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -36,7 +36,7 @@ using FuncSet = llvm::SmallSet; // Module pass to optimize TensorFlow functional ops. 
struct OptimizeFunctionalOpsPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -198,7 +198,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() { } } // namespace -std::unique_ptr> CreateOptimizeFunctionalOpsPass() { +std::unique_ptr> CreateOptimizeFunctionalOpsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 1c92c806585..a744a570929 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -24,75 +24,79 @@ namespace mlir { class FuncOp; class ModuleOp; template -class OpPassBase; +class OperationPass; namespace TFL { class QuantizationSpecs; // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. -std::unique_ptr> CreateLegalizeTFPass(); +// When the given run_tfl_runtime_verification value is true, it will check each +// TFL builtin op towards the TFL runtime capability and the incompatible TF ops +// will be left in the graph without getting legalized. +std::unique_ptr> CreateLegalizeTFPass( + bool run_tfl_runtime_verification); // Creates an instance of the TensorFlow Lite dialect Optimize pass. -std::unique_ptr> CreateOptimizePass(); +std::unique_ptr> CreateOptimizePass(); // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. -std::unique_ptr> CreatePrepareTFPass( +std::unique_ptr> CreatePrepareTFPass( bool unfold_batch_matmul); // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList // pass. -std::unique_ptr> CreateLowerStaticTensorListPass(); +std::unique_ptr> CreateLowerStaticTensorListPass(); // Creates an instance of the TensorFlow Lite dialect Quantize pass. -std::unique_ptr> CreateQuantizePass(); +std::unique_ptr> CreateQuantizePass(); // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. 
-std::unique_ptr> CreatePrepareQuantizePass( +std::unique_ptr> CreatePrepareQuantizePass( const QuantizationSpecs& quant_specs); // Creates an instance of the TensorFlow Lite dialect PostQuantize pass. -std::unique_ptr> CreatePostQuantizePass( +std::unique_ptr> CreatePostQuantizePass( bool emit_quant_adaptor_ops); // Creates an instance of the TensorFlow Lite dialect TrimFunctions // pass. -std::unique_ptr> CreateTrimFunctionsPass( +std::unique_ptr> CreateTrimFunctionsPass( llvm::ArrayRef trim_funcs_whitelist); // Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions // pass. -std::unique_ptr> CreatePrepareCompositeFunctionsPass(); +std::unique_ptr> CreatePrepareCompositeFunctionsPass(); // Creates an instance of the TensorFlow Lite dialect ExtractOphint pass. -std::unique_ptr> CreateExtractOphintPass(); +std::unique_ptr> CreateExtractOphintPass(); // Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass // pass. The composite op is created from the ophint extraction pass. -std::unique_ptr> CreateLegalizeOphintFuncOpPass(); +std::unique_ptr> CreateLegalizeOphintFuncOpPass(); // Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass. -std::unique_ptr> CreateSplitMergedOperandsPass(); +std::unique_ptr> CreateSplitMergedOperandsPass(); // Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass. -std::unique_ptr> CreateOptimizeFunctionalOpsPass(); +std::unique_ptr> CreateOptimizeFunctionalOpsPass(); // Creates an instance of the TensorFlow Lite dialect pass to add default // quantization parameters. -std::unique_ptr> CreateDefaultQuantParamsPass( +std::unique_ptr> CreateDefaultQuantParamsPass( double default_min, double default_max); // Creates an instance of the TensorFlow Lite dialect pass to convert dense // tensor to sparse format. -std::unique_ptr> CreateDenseToSparsePass(); +std::unique_ptr> CreateDenseToSparsePass(); // Creates function pass to legalize TF While to TFL While. 
-std::unique_ptr> CreateLegalizeTFWhilePass(); +std::unique_ptr> CreateLegalizeTFWhilePass(); // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. -std::unique_ptr> CreateWhileOutlinePass(); +std::unique_ptr> CreateWhileOutlinePass(); -// Verifies runtime supports types used. -std::unique_ptr> CreateRuntimeTypeVerifyPass(); +// Verifies runtime constraints. +std::unique_ptr> CreateRuntimeVerifyPass(); } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 86d23a2b0b2..e737e32044d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -30,7 +30,7 @@ namespace TFL { namespace { // Applies all the clean up steps after quantization. -class PostQuantizePass : public FunctionPass { +class PostQuantizePass : public PassWrapper { public: // Constructor used by the PassRegistration. This will remove the adaptor ops. explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {} @@ -135,7 +135,7 @@ void PostQuantizePass::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect PostQuantize pass. -std::unique_ptr> CreatePostQuantizePass( +std::unique_ptr> CreatePostQuantizePass( bool emit_quant_adaptor_ops) { return std::make_unique(emit_quant_adaptor_ops); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index c29e85a0f4d..6179eb2ce64 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -94,7 +94,8 @@ class ConvertEmbeddedLookupFunc { // body with the corresponding fused TFLite op. The replacement need not always // be a fused op, though that is the primary use case. 
class PrepareCompositeFunctionsPass - : public OperationPass { + : public PassWrapper> { public: explicit PrepareCompositeFunctionsPass() {} @@ -211,7 +212,7 @@ void PrepareCompositeFunctionsPass::runOnOperation() { } } // namespace -std::unique_ptr> CreatePrepareCompositeFunctionsPass() { +std::unique_ptr> CreatePrepareCompositeFunctionsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index cdbf4c41539..3387015ed31 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -66,7 +66,8 @@ namespace { // across ops. This step is necessary for post-training quantization and also // making the quantization rule for some operations in the quantization-aware // training quantization simpler. -class PrepareQuantizePass : public FunctionPass { +class PrepareQuantizePass + : public PassWrapper { public: // Constructor used by the PassRegistration and enforce uint8 quantization. explicit PrepareQuantizePass() { @@ -281,7 +282,7 @@ void PrepareQuantizePass::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. -std::unique_ptr> CreatePrepareQuantizePass( +std::unique_ptr> CreatePrepareQuantizePass( const QuantizationSpecs& quant_specs) { return std::make_unique(quant_specs); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index f79543e6db6..012599e96c2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -71,7 +71,7 @@ namespace TFL { namespace { // Prepare TF operations in functions for subsequent legalization. 
-class PrepareTFPass : public FunctionPass { +class PrepareTFPass : public PassWrapper { public: explicit PrepareTFPass() : unfold_batch_matmul_(true) {} explicit PrepareTFPass(bool unfold_batch_matmul) @@ -652,7 +652,7 @@ void PrepareTFPass::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. -std::unique_ptr> CreatePrepareTFPass( +std::unique_ptr> CreatePrepareTFPass( bool unfold_batch_matmul) { return std::make_unique(unfold_batch_matmul); } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 3be335e8c7b..ed40fce3099 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -75,7 +75,7 @@ struct TFLFullQuantization }; // Applies quantization on the model in TFL dialect. -struct QuantizePass : public FunctionPass { +struct QuantizePass : public PassWrapper { void runOnFunction() override; }; @@ -93,7 +93,7 @@ void QuantizePass::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. -std::unique_ptr> CreateQuantizePass() { +std::unique_ptr> CreateQuantizePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc similarity index 63% rename from tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc rename to tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc index d103209ffd9..3268329b1c1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc @@ -22,33 +22,32 @@ namespace mlir { namespace TFL { namespace { -// This pass verifies that the operands and results types are supported by -// TFLite runtime. 
-class RuntimeTypeVerifyPass : public mlir::FunctionPass { +// This pass verifies that the TFL ops meet the TFL runtime constraints. +class RuntimeVerifyPass + : public mlir::PassWrapper { public: - explicit RuntimeTypeVerifyPass() {} + explicit RuntimeVerifyPass() {} private: void runOnFunction() override; }; -void RuntimeTypeVerifyPass::runOnFunction() { +void RuntimeVerifyPass::runOnFunction() { getFunction().walk([&](TflRuntimeVerifyOpInterface op) { - if (failed(op.VerifyTflRuntimeTypes( - op.getOperation(), - /*failure_on_operand_type_mismatch=*/true))) + if (failed(op.VerifyTflRuntimeConstraints( + op.getOperation(), /*failure_on_operand_type_mismatch=*/true))) signalPassFailure(); }); } } // namespace -// Verifies runtime supports types used. -std::unique_ptr> CreateRuntimeTypeVerifyPass() { - return std::make_unique(); +// Verifies TFL runtime constraints. +std::unique_ptr> CreateRuntimeVerifyPass() { + return std::make_unique(); } -static PassRegistration pass( - "tfl-runtime-verify", "TFLite runtime verification"); +static PassRegistration pass("tfl-runtime-verify", + "TFLite runtime verification"); } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc index 7f745727c49..5eb0dc1ab1a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc @@ -66,7 +66,8 @@ namespace mlir { namespace TFL { namespace { -struct SplitMergedOperandsPass : public FunctionPass { +struct SplitMergedOperandsPass + : public PassWrapper { void runOnFunction() override; }; @@ -119,7 +120,7 @@ void SplitMergedOperandsPass::runOnFunction() { /// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands /// pass. 
-std::unique_ptr> CreateSplitMergedOperandsPass() { +std::unique_ptr> CreateSplitMergedOperandsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index 41adc21db35..013ffc26ea8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -45,7 +45,7 @@ namespace { // The pass to trim functions before we legalize to TFL // dialect using the specified whitelist. class TrimFunctionsPass - : public mlir::OperationPass { + : public mlir::PassWrapper> { public: explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_whitelist) @@ -120,7 +120,7 @@ void TrimFunctionsPass::Verify() { // Creates an instance of the TensorFlow Lite dialect TrimFunctions /// pass. -std::unique_ptr> CreateTrimFunctionsPass( +std::unique_ptr> CreateTrimFunctionsPass( llvm::ArrayRef trim_funcs_whitelist) { return std::make_unique(trim_funcs_whitelist); } diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index c2acb93fe78..a7f2a625e65 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -38,7 +38,7 @@ namespace { // This pass outlines the cond/body region of the TFL WhileOp into functions and // replaces the regions with calls to these outlined functions. class WhileOutlinePass - : public mlir::OperationPass { + : public mlir::PassWrapper> { public: explicit WhileOutlinePass() {} @@ -241,7 +241,7 @@ void WhileOutlinePass::runOnOperation() { } // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. 
-std::unique_ptr> CreateWhileOutlinePass() { +std::unique_ptr> CreateWhileOutlinePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 67533197f3e..ab8c1107fc8 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [ tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile' + 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c2120ccc4ab..7132ae6d7e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", @@ -1118,6 +1119,7 @@ tf_cc_test( srcs = ["utils/compile_mlir_util_test.cc"], deps = [ ":compile_mlir_util", + "//tensorflow/compiler/jit", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:test", @@ -1329,6 +1331,7 @@ cc_library( deps = [ ":tensorflow", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc 
index f757a1fe638..e8d32121d1b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -108,8 +108,6 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" >(); - addOperations(); - addInterfaces(); } @@ -161,22 +159,8 @@ LogicalResult Verify(ParallelExecuteOp op) { int output_index = 0; for (auto& region_and_index : llvm::enumerate(regions)) { auto& region = region_and_index.value(); - auto region_index = region_and_index.index(); - - // Each region must include a single block of ops and must not be empty. - if (region.empty()) { - return op.emitOpError() - << "regions must not be empty. " - << "Found an empty region (" << region_index << ")."; - } - - if (!has_single_element(region)) { - return op.emitOpError() - << "regions must be composed of a single block of operations." - << "Expected region (" << region_index << ") with 1 block."; - } - auto* region_terminator = region.front().getTerminator(); + // Check that output types of regions match return operand types. 
for (auto result_type : region_terminator->getOperandTypes()) { if (result_type != @@ -214,8 +198,6 @@ void ParallelExecuteOp::build(Builder* builder, OperationState& state, state.addTypes(output_types); } -LogicalResult ParallelExecuteOp::verify() { return Verify(*this); } - Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) { return getOperation()->getRegion(index).front(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 6600edf35a7..4c20d1ccc4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -43,47 +43,6 @@ class TensorFlowDeviceDialect : public Dialect { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc" -// TODO(b/148642767): Use tablegen to define tf_device.parallel_execute op once -// variadic regions can be expressed in tablegen. -// -// ParallelExecute op concurrently executes variadic number of regions. Regions -// must represent separate sets of instructions to execute concurrently. In -// order to represent concurrently executed regions with dependencies, multiple -// ParallelExecute ops can be used instead. As so, regions within -// ParallelExecute op must not have control/data dependencies. While explicit -// dependencies between regions are disallowed, ParallelExecute op does not -// prevent implicit communication between regions (e.g. communication via -// send/recvs). In this case, users of ParallelExecute op must provide correct -// control dependencies between regions to guarantee correctness. Regions in -// ParallelExecute may include Resource ops. In the case where different regions -// include ops access the same resource, the users of the ParallelExecute op -// must provide mechanism (via send/recvs or via control dependencies) to -// guarantee correct ordering. Sequential ordering of ops within a region is -// guaranteed. 
Also, sequential ordering of ops before/after ParallelExecute ops -// are guaranteed. That is, execution of regions inside ParallelExecute op is -// blocked until all inputs to all regions are materialized and ops following -// ParallelExecute op are blocked until all regions are executed. -class ParallelExecuteOp - : public Op::Impl> { - public: - using Op::Op; - - static void build(Builder* builder, OperationState& state, int num_regions, - llvm::ArrayRef output_types); - - static StringRef getOperationName() { return "tf_device.parallel_execute"; } - - LogicalResult verify(); - Block& GetRegionBlockWithIndex(unsigned index); - Operation::result_range GetRegionOutputs(unsigned region_index); - - // Checks if a tf_device.parallel_execute index'th region block wraps a single - // operation and the single operation results are perfectly forwarded to the - // region block's return. - bool RegionWrapsSingleOp(unsigned index); -}; - } // namespace tf_device } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 7c2bd9daa40..4673e86921a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -125,6 +125,55 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { }]; } +def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute", + [SingleBlockImplicitTerminator<"ReturnOp">]> { + let description = [{ + ParallelExecute op concurrently executes variadic number of regions. Regions + must represent separate sets of instructions to execute concurrently. In + order to represent concurrently executed regions with dependencies, multiple + ParallelExecute ops can be used instead. As so, regions within + ParallelExecute op must not have control/data dependencies. 
+ + While explicit dependencies between regions are disallowed, ParallelExecute + op does not prevent implicit communication between regions (e.g. + communication via send/recvs). In this case, users of ParallelExecute op + must provide correct control dependencies between regions to guarantee + correctness. Regions in ParallelExecute may include Resource ops. + + In the case where different regions include ops access the same resource, + the users of the ParallelExecute op must provide mechanism (via send/recvs + or via control dependencies) to guarantee correct ordering. Sequential + ordering of ops within a region is guaranteed. Also, sequential ordering of + ops before/after ParallelExecute ops are guaranteed. That is, execution of + regions inside ParallelExecute op is blocked until all inputs to all regions + are materialized and ops following ParallelExecute op are blocked until all + regions are executed. + }]; + + let results = (outs + Variadic:$execute_outputs + ); + + let regions = (region VariadicRegion>:$regions); + + let extraClassDeclaration = [{ + Block& GetRegionBlockWithIndex(unsigned index); + Operation::result_range GetRegionOutputs(unsigned region_index); + + // Checks if a tf_device.parallel_execute index'th region block wraps a + // single operation and the single operation results are perfectly forwarded + // to the region block's return. 
+ bool RegionWrapsSingleOp(unsigned index); + }]; + + let builders = [ + OpBuilder<"Builder* builder, OperationState& state, int num_regions," + "llvm::ArrayRef output_types">, + ]; + + let verifier = [{ return Verify(*this); }]; +} + def TfDevice_ReplicateOp : TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> { let summary = "Wraps an N-way replicated computation."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 92590af2aea..1b13558b692 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -208,7 +208,7 @@ static Type InferReductionOpType(Value input, Value reduction_indices, int64_t num_reduce_dim = 0; llvm::SmallVector is_reduce_dim(rank, false); - for (APInt index : indices.getValues()) { + for (const APInt &index : indices.getValues()) { int64_t dim = GetDimForAxis(index.getSExtValue(), rank); // Invalid input. if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); @@ -404,11 +404,11 @@ static bool AreCancellablePermutations(DenseIntElementsAttr perm0, if (perm0.getNumElements() != perm1.getNumElements()) return false; SmallVector perm0_values; - for (auto value : perm0.getIntValues()) + for (const auto &value : perm0.getIntValues()) perm0_values.push_back(value.getSExtValue()); SmallVector perm1_values; - for (auto value : perm1.getIntValues()) + for (const auto &value : perm1.getIntValues()) perm1_values.push_back(value.getSExtValue()); for (int i = 0; i < perm0_values.size(); ++i) { @@ -2548,12 +2548,15 @@ static LogicalResult Verify(SizeOp op) { // SliceOp //===----------------------------------------------------------------------===// -// Verifies that, +// Verifies that: // // - operands begin and size are 1D with the same number of elements. // - if the input is a ranked tensor, the rank of the input equals the number // of elements in operands begin and size. 
-// - if begin are constants, 0 <= begin[i] < input_ty.getShape()[i] +// - if begin are constants, that +// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] +// - if begins aren't constant but the input is a ranked tensor, that +// size[i] <= input_ty.getShape()[i] // static LogicalResult Verify(SliceOp op) { RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); @@ -2587,7 +2590,7 @@ static LogicalResult Verify(SliceOp op) { bool constant_slice_sizes = matchPattern(op.size(), m_Constant(&slice_sizes)); int dim = 0; - for (APInt raw_begin_index : begin_indices.getValues()) { + for (const APInt &raw_begin_index : begin_indices.getValues()) { int64_t begin_index = raw_begin_index.getSExtValue(); int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; int64_t slice_size = constant_slice_sizes @@ -2603,6 +2606,20 @@ static LogicalResult Verify(SliceOp op) { } ++dim; } + } else if (input_ty) { + // If the inputs are ranked, we can do a few more sanity checks. 
+ DenseIntElementsAttr slice_sizes; + if (matchPattern(op.size(), m_Constant(&slice_sizes))) { + auto input_shape = input_ty.getShape(); + for (int64_t i = 0; i < input_ty.getRank(); ++i) { + int64_t slice_size = slice_sizes.getValue(i).getInt(); + int64_t input_size = input_shape[i]; + if (slice_size != -1 && input_size != -1 && slice_size > input_size) { + return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " + "is unknown at compile time"; + } + } + } } return success(); @@ -3340,7 +3357,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x, x_type.getDimSize((*attr_shape.begin()).getSExtValue())); } else { const_shape.reserve(attr_shape.getNumElements()); - for (auto dim : attr_shape) + for (const auto &dim : attr_shape) const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); } return TransposeOp::build( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index 5f11ad25848..1866879c465 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -27,6 +27,7 @@ module { // CHECK-LABEL: func @tpu0_func // CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK-SAME: sym_visibility = "private" // CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) // CHECK: return %[[TPU0_FUNC_B_OUTPUT]] } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt index a3f78e282bc..df740bc6ccd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt @@ -93,7 +93,7 @@ library { # CHECK: return # CHECK: func @test_func_name0 # CHECK-SAME: 
tf.resource_arg_unique_id = 0 -# CHECK-SAME tf.resource_arg_unique_id = 0 +# CHECK-SAME: tf.resource_arg_unique_id = 0 # CHECK: tf_executor.graph # CHECK: tf_executor.fetch # CHECK: return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir index 4dce57fbd10..77e53ec041e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail +// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail // Test that simple TF operations can be inlined. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir index 85b4e3671ac..ca853505845 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir @@ -48,7 +48,7 @@ func @transpose_resnet_layer(%arg0: tensor, // input } : (tensor, tensor<7x7x3x64xf32>) -> tensor // CHECK: %[[CONV0:[0-9]*]] = "tf.Conv2D" - // CHECK-SAME %[[PAD]] + // CHECK-SAME: %[[PAD]] // CHECK-SAME: data_format = "NHWC" // CHECK-SAME: strides = [1, 2, 2, 1] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4b38465257d..0195b1b0d3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -163,12 +163,12 @@ func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor< } func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = xla_hlo.pow %arg0, %arg0 : tensor<2xf32> + %0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32> return %0 : tensor<2xf32> } func @pow_dynamic(%arg0: tensor) -> tensor { - %0 = 
xla_hlo.pow %arg0, %arg0 : tensor + %0 = xla_hlo.power %arg0, %arg0 : tensor return %0 : tensor } @@ -184,7 +184,7 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te %8 = xla_hlo.constant dense<1> : tensor<3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - %11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -203,7 +203,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - %11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -461,32 +461,32 @@ func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { } func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @cos_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.cos"(%arg0) : (tensor) -> tensor + %0 = "xla_hlo.cosine"(%arg0) : (tensor) -> tensor return 
%0 : tensor } func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @exp_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.exp"(%arg0) : (tensor) -> tensor + %0 = "xla_hlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -551,17 +551,17 @@ func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { } func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @neg_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.neg"(%arg0) : (tensor) -> tensor + %0 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor return %0 : tensor } func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -577,17 +577,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @sin_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.sin"(%arg0) : (tensor) -> tensor + %0 = "xla_hlo.sine"(%arg0) : (tensor) -> tensor return %0 : tensor } func @sin_unranked(%arg0: 
tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -677,6 +677,11 @@ func @size_rank_one_i64(%arg0: tensor) -> tensor { return %0 : tensor } +func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + return %0 : tensor<3xcomplex> +} + // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: func @biasAdd_NHWC( @@ -1481,3 +1486,10 @@ func @size_rank_one_i64(%arg0: tensor) -> tensor { // CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: return [[VAL_366]] : tensor // CHECK: } + +// CHECK-LABEL: func @complex( +// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex> { +// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> +// CHECK: return [[VAL_369]] : tensor<3xcomplex> +// CHECK: } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index 2451947a4a5..f22b537ece4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -7,6 +7,7 @@ glob_lit_tests( driver = "@llvm-project//mlir:run_lit.sh", tags_override = { "preserve-entry-func-names.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed. + "tf_add.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed. 
}, test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir index 29f7f860f1c..beb7312543b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 5c8041e0436..544264177b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -116,8 +116,8 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @shape_from_while_to_cond_body_functions func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor>>, %arg2: tensor>>) -> tensor<4xf32> { - // CHECK "tf.While" - // CHECK-SAME (tensor<4xf32>, tensor>>, tensor>>) -> (tensor<4xf32>, tensor>>, tensor>>) + // CHECK: "tf.While" + // CHECK-SAME: (tensor<4xf32>, tensor>>, tensor>>) -> (tensor<4xf32>, tensor>>, tensor>>) %0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor>>, tensor>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor>>) return %0#0 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 8031d5d3c28..e4d6b25fea4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1682,6 +1682,23 @@ func @testSlice_begin_out_of_bound(%arg0: tensor<4xi32>) -> tensor<2xi32> { // ----- +func @testSlice_unknown_begin_out_of_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> { + %sizes = "tf.Const"() {value = dense<[5]> : tensor<1xi64>} : () -> (tensor<1xi64>) + // expected-error @+1 {{requires size[i] <= Di, even if begin[i] is unknown at compile time}} + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + +func @testSlice_unknown_begin_in_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> { + %sizes = "tf.Const"() {value = dense<[4]> : tensor<1xi64>} : () -> (tensor<1xi64>) + %0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + // Valid StridedSlice operation. func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor { %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir index 07863e3e806..0eb5f878c2a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir @@ -188,7 +188,7 @@ func @parallel_execute_single_region() { // Check that a parallel_execute op with empty regions are not allowed. func @parallel_execute_empty_region() { "tf_device.parallel_execute"() ( { -// expected-error@-1 {{'tf_device.parallel_execute' op regions must not be empty. 
Found an empty region (0).}} +// expected-error@-1 {{'tf_device.parallel_execute' op region #0 ('regions') failed to verify constraint: region with 1 blocks}} }, { tf_device.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir index 090e80739e1..c59b2ebdd7a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir @@ -51,12 +51,12 @@ module attributes {tf_saved_model.semantics} { // Test case: Delete recursively dead cycle. - // CHECK-NOT func @recursively_dead0 + // CHECK-NOT: func @recursively_dead0 func @recursively_dead0() { "some_dialect.call"() { callee = @recursively_dead1 } : () -> () return } - // CHECK-NOT func @recursively_dead1 + // CHECK-NOT: func @recursively_dead1 func @recursively_dead1() { "some_dialect.call"() { callee = @recursively_dead0 } : () -> () return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index ce44a562aca..38aa078358b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -86,3 +86,21 @@ module attributes {tf_saved_model.semantics} { } } +// ----- + +module attributes {tf_saved_model.semantics} { + + // CHECK-NOT: tf_saved_model.global_tensor + "tf_saved_model.global_tensor"() {sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor, value = dense<1.0> : tensor } : () -> () + + func @f(%arg1: tensor>> {tf_saved_model.bound_input = @"v"}, %arg2: tensor>> {tf_saved_model.bound_input = @"v2"}) + attributes 
{tf_saved_model.exported_names = ["f"]} { + // CHECK: "tf.Const"() + %0 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor>>) -> tensor + + // CHECK: "tf.Const"() + %1 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor>>) -> tensor + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir index 7fdfd90c64a..abb32cd3bf3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir @@ -19,14 +19,14 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"} // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) - // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"} - // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"} // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUCompileSucceededAssert" "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return }) {device = "/device:CPU:0"} : () -> () + // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"} + // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"} // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) %execute = "tf_device.launch"() ( { @@ -154,17 +154,17 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> }) {device = "/device:CPU:0"} : () -> (tensor, tensor) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = 
"tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false} - // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"} - // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"} // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext" %3:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) - // CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"} - // CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"} "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return }) {device = "/device:CPU:0"} : () -> () + // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"} + // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"} + // CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"} + // CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"} // CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, [%[[COPY1]], %[[COPY3]]] as %[[R1:.*]]: tensor<3x3x1x32xf32>) %5:2 = tf_device.replicate([%2#0, %3#0] as %r0: tensor<3x3x1x32xf32>, [%2#1, %3#1] as %r1: tensor<3x3x1x32xf32>) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} { @@ -210,3 +210,137 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resou } return %5#0 : tensor } + +// ----- + +// Tests that the pass can transform execution with model parallelism and no +// replication. 
+// +// The following TPUCompileMetadataProto is used: +// args { +// dtype: DT_FLOAT +// shape { +// dim { +// size: 128 +// } +// } +// } +// num_replicas: 1 +// num_cores_per_replica: 2 + +// CHECK-LABEL: func @parallel_execute +func @parallel_execute(%arg0: tensor<*x!tf.resource>) { + // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + %compile:3 = "tf_device.launch"() ( { + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor + }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} + // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} + // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" + %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<128xf32>, tensor<128xf32>) + // CHECK: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + // CHECK: "tf_device.parallel_execute" + "tf_device.parallel_execute"() ({ + // CHECK-NEXT: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) + // CHECK-SAME: device = "/device:TPU:0" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/device:TPU:0" + "tf_device.launch"() ( { + "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor) -> () + tf_device.return + }) {device = "/device:TPU:0"} : () -> () + tf_device.return + }, + { + // CHECK: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) + // CHECK-SAME: device = 
"/device:TPU:1" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute"(%[[COPY1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/device:TPU:1" + "tf_device.launch"() ( { + "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor) -> () + tf_device.return + }) {device = "/device:TPU:1"} : () -> () + tf_device.return + }) {} : () -> () + return +} + +// ----- + +// Tests that the pass can transform execution with model parallelism and +// replication. +// +// The following TPUCompileMetadataProto is used: +// args { +// dtype: DT_FLOAT +// shape { +// dim { +// size: 128 +// } +// } +// } +// num_replicas: 2 +// num_cores_per_replica: 2 + +// CHECK-LABEL: func @replicated_parallel_execute +// CHECK-SAME: (%[[ARG0:[a-z0-9]+]]: tensor<*x!tf.resource>, %[[ARG1:[a-z0-9]+]]: tensor<*x!tf.resource>) +func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) { + // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + %compile:3 = "tf_device.launch"() ( { + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor + }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} + // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} + // CHECK-DAG: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"(%[[ARG0]]) + // CHECK-DAG: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"(%[[ARG1]]) + %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<128xf32>, tensor<128xf32>) + %3:2 = "tf.IteratorGetNext"(%arg1) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<128xf32>, tensor<128xf32>) + // CHECK: 
"tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"} + // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"} + // CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:2"} + // CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:3"} + // CHECK-NEXT: tf_device.replicate + // CHECK-SAME: ([%[[COPY0]], %[[COPY1]]] as %[[R0:[a-z0-9]+]]: tensor<128xf32>, [%[[COPY2]], %[[COPY3]]] as %[[R1:[a-z0-9]+]]: tensor<128xf32>) + tf_device.replicate([%2#0, %3#0] as %r0: tensor<128xf32>, [%2#1, %3#1] as %r1: tensor<128xf32>) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/device:TPU:2", "/device:TPU:3"]}} { + // CHECK: "tf_device.parallel_execute" + "tf_device.parallel_execute"() ({ + // CHECK: "tf.TPUExecute"(%[[R0]], %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + "tf_device.launch"() ( { + "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + }, + { + // CHECK: "tf.TPUExecute"(%[[R1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" + "tf_device.launch"() ( { + "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_1"} : () -> () + tf_device.return + }) {} : () -> () + tf_device.return + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir index fe33624c1d7..b9bc0e17f2a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir @@ -5,6 +5,7 @@ module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", // CHECK: std.constant // CHECK: TPUCompile // CHECK: TPUExecute +// CHECK-NOT: func @_func tf_executor.graph { %outputs, %control = tf_executor.island wraps "std.constant"() {value = dense<2.000000e+00> : tensor} : () -> tensor %outputs_0, %control_1 = tf_executor.island wraps "std.constant"() {value = dense<3.000000e+00> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 58b4901b548..01c30eabd35 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -39,7 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices"; // Analyzes the inputs to LaunchFuncOps in the module, and annotates their // invoked functions whether each input has the same data across replicas. 
struct AnnotateParameterReplication - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -90,7 +91,8 @@ void AnnotateParameterReplication::runOnOperation() { } // namespace -std::unique_ptr> CreateAnnotateParameterReplicationPass() { +std::unique_ptr> +CreateAnnotateParameterReplicationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index bcf08c6b3ed..5d842f53bd9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -43,7 +43,8 @@ namespace TF { namespace { // Replace TF BatchMatMul by TF Einsum -struct BatchMatMulToEinsumPass : public FunctionPass { +struct BatchMatMulToEinsumPass + : public PassWrapper { void runOnFunction() override; }; @@ -117,7 +118,7 @@ static PassRegistration pass( "tf-batch-matmul-to-tf-einsum", "Replace TF BatchMatMul op by TF Einsum op."); -std::unique_ptr> CreateBatchMatMulToEinsumPass() { +std::unique_ptr> CreateBatchMatMulToEinsumPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 2e1201c10c5..73130640d1b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -97,6 +97,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addPass(CreateTPUShardingIdentificationPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addPass(CreateTPURewritePass()); + pm.addPass(createSymbolDCEPass()); pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); pm.addNestedPass(CreateTPUDynamicLayoutPass()); pm.addNestedPass(CreateTPUMergeVariablesWithExecutePass()); diff --git 
a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 48f25b50ef6..2b8ab85be38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -37,7 +37,8 @@ namespace TFDevice { namespace { -struct ClusterFormationPass : public FunctionPass { +struct ClusterFormationPass + : public PassWrapper { void runOnFunction() override; }; @@ -229,7 +230,7 @@ void ClusterFormationPass::runOnFunction() { } // namespace -std::unique_ptr> CreateClusterFormationPass() { +std::unique_ptr> CreateClusterFormationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index aee6e72e7d6..aa4c071abdf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -39,7 +39,7 @@ constexpr char kDeviceAttr[] = "device"; constexpr char kFuncAttr[] = "func"; struct ClusterOutliningPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -66,6 +66,10 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, FuncOp outlined_func = FuncOp::create(launch_op.getLoc(), func_name_prefix, func_type); + // This function is not externally visible and marking it private would allow + // symbol-dce pass to remove it when it is not referenced anymore. + outlined_func.setVisibility(FuncOp::Visibility::Private); + // Create function body. 
Block* outlined_func_block = outlined_func.addEntryBlock(); @@ -132,7 +136,7 @@ void ClusterOutliningPass::runOnOperation() { } // namespace -std::unique_ptr> CreateClusterOutliningPass() { +std::unique_ptr> CreateClusterOutliningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index d9715d11922..8951b49acb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -276,7 +276,7 @@ int64_t GetFirstIfIndicesAreContiguous(Value indices) { if (!const_op) return -1; int64_t last_index = -1; int64_t first_index = -1; - for (auto ind : const_op.value().getValues()) { + for (const auto& ind : const_op.value().getValues()) { if (last_index == -1) { last_index = ind.getSExtValue(); first_index = last_index; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc index 53129dbf703..44e7c5dab6a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc @@ -52,7 +52,7 @@ bool DecodeOpaqueValueInConstantOp(Operation *op) { } // A pass to decode opaque constant values into readable ones. -struct DecodeConstant : public FunctionPass { +struct DecodeConstant : public PassWrapper { void runOnFunction() override { auto walk_result = getFunction().walk([](Operation *op) { return DecodeOpaqueValueInConstantOp(op) ? 
WalkResult::advance() @@ -64,7 +64,7 @@ struct DecodeConstant : public FunctionPass { } // namespace -std::unique_ptr> CreateDecodeConstantPass() { +std::unique_ptr> CreateDecodeConstantPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h index 1acbb2e3a55..ed82261b21b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h @@ -23,7 +23,7 @@ namespace TF { // Creates a pass to decode and reset opaque values in constant ops into // readable values. // Note that this pass assumes RaiseTFControlFlow pass has already been run. -std::unique_ptr> CreateDecodeConstantPass(); +std::unique_ptr> CreateDecodeConstantPass(); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index a439d7dcc45..7c8734fb695 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -38,7 +38,8 @@ namespace { // NOTE: This pass does not support `use_locking=true` for a lot of resource // operations. So decomposition may not be correct outside of backends like XLA, // which automatically locks all resource variables. -struct DecomposeResourceOps : public FunctionPass { +struct DecomposeResourceOps + : public PassWrapper { void runOnFunction() override { // Add lowering patterns to the list. 
OwningRewritePatternList patterns; @@ -50,7 +51,7 @@ struct DecomposeResourceOps : public FunctionPass { } // namespace -std::unique_ptr> CreateDecomposeResourceOpsPass() { +std::unique_ptr> CreateDecomposeResourceOpsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 483da1c70f7..69d5de659fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -354,7 +354,8 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( } // Transform Einsum to other TF Ops for the supported variants. -struct TransformEinsumPass : public FunctionPass { +struct TransformEinsumPass + : public PassWrapper { void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index f7569917b41..0d72a7638a3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -57,7 +57,7 @@ struct IslandResult { }; struct ExecutorIslandCoarsening - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -346,7 +346,7 @@ void ExecutorIslandCoarsening::runOnFunction() { } // namespace -std::unique_ptr> CreateTFExecutorIslandCoarseningPass() { +std::unique_ptr> CreateTFExecutorIslandCoarseningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 71e5d291292..9a533798208 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -43,7 +43,8 
@@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined"; // Inlining the islands calling into the nested module that was outlined. // This is the end of the TPU bridge in V1 compatibility mode. struct TPUBridgeExecutorIslandInlining - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -95,7 +96,7 @@ PassRegistration tpu_pass( } // namespace -std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandInliningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index 452ac076ac9..97111efb2f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -59,7 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status"; // TPU-annotated operations and intended to preserve backward compatibility with // TFv1. 
struct TpuV1BridgeExecutorIslandCoarsening - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -322,7 +323,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() { } // namespace -std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandCoarseningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index db13d6b3875..08645333d5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -44,7 +44,8 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func"; // This is only intended for V1 compatibility mode where the bridge runs without // feed/fetches on session create/extend. struct TPUBridgeExecutorIslandOutlining - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -160,7 +161,7 @@ PassRegistration tpu_pass( } // namespace -std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandOutliningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index ad404182658..fe9c10781fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -58,7 +58,7 @@ limitations under the License. 
namespace mlir { namespace { -class SwitchFoldPass : public mlir::FunctionPass { +class SwitchFoldPass : public mlir::PassWrapper { public: void runOnFunction() override; }; @@ -279,7 +279,7 @@ void SwitchFoldPass::runOnFunction() { } // namespace mlir namespace tf_executor { -std::unique_ptr> CreateSwitchFoldPass() { +std::unique_ptr> CreateSwitchFoldPass() { return std::make_unique(); } } // namespace tf_executor diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 088080c603b..d3b064f3efa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -42,7 +42,7 @@ namespace { // support resources/variables . Further, this contract also ensures that this // pass lowers from saved model to pure TF. Hence it fails, if it cannot lower. struct FreezeGlobalTensorsPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -85,6 +85,7 @@ void FreezeGlobalTensorsPass::runOnOperation() { } // Replace the arg with a tf.Const op in the function body. 
+ builder.setInsertionPointToStart(&func.getBody().front()); auto const_op = builder.create(global_tensor.getLoc(), global_tensor.value()); args_to_erase.push_back(i); @@ -113,7 +114,7 @@ static PassRegistration pass( "tf-saved-model-freeze-global-tensors", "Freeze tf_saved_model.global_tensor's in func bodies."); -std::unique_ptr> CreateFreezeGlobalTensorsPass() { +std::unique_ptr> CreateFreezeGlobalTensorsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index b502b0ceb01..91bbac235e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -34,7 +34,7 @@ namespace TF { namespace { struct FunctionalControlFlowToCFG - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -312,7 +312,7 @@ void FunctionalControlFlowToCFG::runOnFunction() { } // namespace -std::unique_ptr> CreateTFFunctionalControlFlowToCFG() { +std::unique_ptr> CreateTFFunctionalControlFlowToCFG() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index a88ea2f387d..736af741842 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -35,7 +35,7 @@ namespace { // GpuOpFusionPass is a pass performing fusion specific to GPU targets. // This is an ad-hoc pass for now, but should be integrated with some notion // of "target" in the MLIR pipeline in the future. 
-class GpuOpFusionPass : public FunctionPass { +class GpuOpFusionPass : public PassWrapper { public: void runOnFunction() final; }; @@ -123,7 +123,7 @@ void GpuOpFusionPass::runOnFunction() { } // namespace -std::unique_ptr> CreateGpuOpFusionPass() { +std::unique_ptr> CreateGpuOpFusionPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 6e022a64262..498b9fa79a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -84,7 +84,7 @@ void PruneGraph(GraphOp graph) { namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. -struct GraphPruning : public FunctionPass { +struct GraphPruning : public PassWrapper { void runOnFunction() override { getFunction().walk([](tf_executor::GraphOp graph) { // For TensorFlow V1.0 compatibility: when importing a graph without @@ -100,7 +100,7 @@ struct GraphPruning : public FunctionPass { } // namespace -std::unique_ptr> CreateTFExecutorGraphPruningPass() { +std::unique_ptr> CreateTFExecutorGraphPruningPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index 9319e91064d..bce18c0b4b7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -57,7 +57,7 @@ namespace { constexpr char kDeviceAttr[] = "device"; struct LaunchToDeviceAttributePass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -122,7 +122,7 @@ void LaunchToDeviceAttributePass::runOnFunction() { } // anonymous namespace -std::unique_ptr> CreateLaunchToDeviceAttributePass() { +std::unique_ptr> 
CreateLaunchToDeviceAttributePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 0b03c522596..e76a8da0b29 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -36,7 +36,8 @@ namespace { // LayoutAssignmentPass assigns optimal data layout (data format) for all // layout sensitive operations. -class LayoutAssignmentPass : public FunctionPass { +class LayoutAssignmentPass + : public PassWrapper { public: LayoutAssignmentPass() = default; explicit LayoutAssignmentPass(const std::string& force_data_format) { @@ -57,7 +58,8 @@ class LayoutAssignmentPass : public FunctionPass { // MoveTransposesPass moves all Transpose ops to the beginning or to the end of // the basic block where they are defined. This will allow canonicalzer to // delete redundant transposes. 
-class MoveTransposesPass : public FunctionPass { +class MoveTransposesPass + : public PassWrapper { public: enum class Direction { kBegin, kEnd }; @@ -316,7 +318,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list) { SmallVector permutation; auto attr = permutation_op.value().cast(); - for (auto value : attr.getIntValues()) + for (const auto& value : attr.getIntValues()) permutation.push_back(value.getSExtValue()); if (failed(fold_operands.FoldOperandsPermutation(permutation))) return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 0ec30f44ce7..50f77cd9c3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -31,7 +31,7 @@ namespace mlir { namespace TF { namespace { -class LegalizeHloToTf : public FunctionPass { +class LegalizeHloToTf : public PassWrapper { public: LegalizeHloToTf() = default; LegalizeHloToTf(const LegalizeHloToTf &) {} @@ -76,7 +76,7 @@ static PassRegistration pass( } // end namespace -std::unique_ptr> CreateLegalizeHloToTfPass() { +std::unique_ptr> CreateLegalizeHloToTfPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 8a71005bf70..853fd806c5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -64,6 +64,7 @@ def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; +def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; //===----------------------------------------------------------------------===// // Unary op patterns. 
//===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc index ecd59442bf4..f6be97e51c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc @@ -23,7 +23,7 @@ namespace { // Lowers some of the TensorFlow operations that can be represented using other // TensorFlow operations. -struct LowerTF : public FunctionPass { +struct LowerTF : public PassWrapper { void runOnFunction() override { // Add lowering patterns to the list. OwningRewritePatternList patterns; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc index a42e7ea8f71..02e1c994986 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc @@ -74,8 +74,9 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( namespace { struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass - : public OperationPass< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> { + : public PassWrapper< + MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, + OperationPass> { void runOnOperation() override { if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification( getOperation()))) { @@ -90,7 +91,7 @@ static PassRegistration< pass("tf-mark-func-visibility", "Use tf.entry_function to mark function visibility."); -std::unique_ptr> +std::unique_ptr> CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() { return std::make_unique< MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>(); @@ -110,8 +111,8 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage( namespace { struct 
MarkFunctionVisibilityUsingSavedModelLinkagePass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override { if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) { signalPassFailure(); @@ -124,7 +125,7 @@ static PassRegistration pass( "tf-saved-model-mark-func-visibility", "Use tf_saved_model linkage information to mark function visibility."); -std::unique_ptr> +std::unique_ptr> CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc index c62b2a539b5..94fdfb310ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc @@ -35,7 +35,7 @@ namespace mlir { namespace { class MaterializePassthroughOpPass - : public FunctionPass { + : public PassWrapper { public: void runOnFunction() override; }; @@ -96,7 +96,7 @@ void MaterializePassthroughOpPass::runOnFunction() { } // namespace namespace TF { -std::unique_ptr> CreateMaterializePassthroughOpPass() { +std::unique_ptr> CreateMaterializePassthroughOpPass() { return std::make_unique(); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index df8d1aeed16..173015fa74f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -33,7 +33,7 @@ namespace { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc" // Canonicalize operations in functions. 
-struct TFOptimizePass : public FunctionPass { +struct TFOptimizePass : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); @@ -71,7 +71,7 @@ void CreateTFStandardPipeline(OpPassManager &pm, pm.addNestedPass(createCSEPass()); } -std::unique_ptr> CreateTFOptimizePass() { +std::unique_ptr> CreateTFOptimizePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index 0fb62cb064d..75d2bc06482 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -27,23 +27,20 @@ def HasOneUse : Constraint>; // If we see a Conv2D op followed by Mul, then multiply the filter // with the value in Mul. -def FuseMulAndConv2D : Pat<(TF_MulOp (TF_Conv2DOp:$output $input, - (ConstantOp F32ElementsAttr:$filter), - $strides, - $use_cudnn, - $padding, $explicit_padding, - IsDataFormatNHWC:$data_format, - $dilations), - (ConstantOp F32ElementsAttr:$value)), -// TODO(karimnosseir): Add check for output is of rank 4. - (TF_Conv2DOp $input, - (TF_MulOp (ConstantOp $filter), - (ConstantOp $value)), - $strides, - $use_cudnn, - $padding, $explicit_padding, $data_format, - $dilations), - [(BroadcastableElements $filter, $value), (HasOneUse $output)]>; +def FuseMulAndConv2D : + Pat<(TF_MulOp:$mul (TF_Conv2DOp:$conv $input, + (ConstantOp:$filter F32ElementsAttr:$filter_value), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (ConstantOp:$multiplier F32ElementsAttr:$mul_value)), +// TODO(karimnosseir): Add check for $conv is of rank 4. 
+ (TF_Conv2DOp $input, + (TF_MulOp (ConstantOp $filter_value, (location $filter)), + (ConstantOp $mul_value, (location $multiplier)), + (location $mul)), + $strides, $use_cudnn, $padding, $explicit_padding, $data_format, + $dilations, (location $conv)), + [(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>; // This rule does the following pattern match and rewrite: // diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 74b9df3fe9f..550100c8ebf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -41,7 +41,7 @@ namespace mlir { namespace tf_saved_model { namespace { struct OptimizeGlobalTensorsPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -296,7 +296,7 @@ static PassRegistration pass( "tf-saved-model-optimize-global-tensors", "Optimize tf_saved_model.global_tensor's."); -std::unique_ptr> CreateOptimizeGlobalTensorsPass() { +std::unique_ptr> CreateOptimizeGlobalTensorsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index b5ecd5bd32b..693d6d964db 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -83,7 +83,7 @@ namespace TFDevice { namespace { struct ParallelExecuteToIslandsPass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -251,7 +251,7 @@ void ParallelExecuteToIslandsPass::runOnFunction() { } } // anonymous namespace -std::unique_ptr> CreateParallelExecuteToIslandsPass() { +std::unique_ptr> CreateParallelExecuteToIslandsPass() { return 
std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 48c83ef8813..d6da961eb0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -24,36 +24,36 @@ namespace mlir { // Creates a pass that breaks up an island with multiple ops into multiple // islands, each with a single op. -std::unique_ptr> CreateBreakUpIslandsPass(); +std::unique_ptr> CreateBreakUpIslandsPass(); // Creates a pass that converts mlir functions consisting of mlir ops into a // tf_executor dialect as a single island. -std::unique_ptr> +std::unique_ptr> CreateFunctionalToExecutorDialectConversionPass(); namespace TF { // Transforms functional control flow operations in the standard TensorFlow // dialect to MLIR Control Flow Graph (CFG) form. -std::unique_ptr> CreateTFFunctionalControlFlowToCFG(); +std::unique_ptr> CreateTFFunctionalControlFlowToCFG(); // Materialize the MlirPassthroughOp by replacing it with the MLIR module // attached as an attribute. -std::unique_ptr> CreateMaterializePassthroughOpPass(); +std::unique_ptr> CreateMaterializePassthroughOpPass(); // Performs Shape Inference on the TensorFlow dialect using the global registry. -std::unique_ptr> CreateTFShapeInferencePass(); +std::unique_ptr> CreateTFShapeInferencePass(); // Optional pass which will unroll BatchMatMul and use only MatMul -std::unique_ptr> CreateUnrollBatchMatMulPassPass(); +std::unique_ptr> CreateUnrollBatchMatMulPassPass(); // Optional pass which will map TF BatchMatMul to TF Einsum -std::unique_ptr> CreateBatchMatMulToEinsumPass(); +std::unique_ptr> CreateBatchMatMulToEinsumPass(); // Optimizes Tensorflow graph. -std::unique_ptr> CreateTFOptimizePass(); +std::unique_ptr> CreateTFOptimizePass(); // Performs specific fusion for GPU targets. 
-std::unique_ptr> CreateGpuOpFusionPass(); +std::unique_ptr> CreateGpuOpFusionPass(); struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { @@ -82,14 +82,14 @@ void CreateTFStandardPipeline(OpPassManager& pm, const StandardPipelineOptions& options); // Propagates device attributes of resources from callers to callees. -std::unique_ptr> CreateResourceDeviceInferencePass(); +std::unique_ptr> CreateResourceDeviceInferencePass(); // Creates a pass that promotes resource reads/writes in the main function to // inputs and outputs of the main function, assuming that resource operations // have already been decomposed and function calls have already been inlined. // The pass also annotates the input arguments for resources with the indices // of their aliasing output arguments. -std::unique_ptr> CreatePromoteResourcesToArgsPass(); +std::unique_ptr> CreatePromoteResourcesToArgsPass(); // Marks function visibility using tf.entry_function specification. That is, // functions with tf.entry_function attributes are marked with public @@ -98,11 +98,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( ModuleOp module); // Creates a pass that uses tf.entry_function specification to mark function // visibility. -std::unique_ptr> +std::unique_ptr> CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass(); // Creates a simple device assignment pass on TF dialect for CoreRT use case. -std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( +std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( llvm::StringRef default_device); // Performs resource lifting on the function body to hoist resource variable @@ -112,25 +112,26 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function); // Converts stack ops into operations on local variables, which can later be // removed by resource lifting. Requires known maximum sizes of stacks and // known element shapes of push ops. 
-std::unique_ptr> CreateStackOpsDecompositionPass(); +std::unique_ptr> CreateStackOpsDecompositionPass(); // Converts tensor list operations into operations on buffers and sizes. Needs // static shapes and known max element count. -std::unique_ptr> CreateTensorListOpsDecompositionPass(); +std::unique_ptr> CreateTensorListOpsDecompositionPass(); // Converts tensor array ops into operations on local variables, which can later // be removed by resource lifting. Requires known sizes and known element shapes // (either defined in TensorArrayV3 or implied in the first write). -std::unique_ptr> CreateTensorArrayOpsDecompositionPass(); +std::unique_ptr> +CreateTensorArrayOpsDecompositionPass(); // Create a pass that legalize HLO to TF dialect. -std::unique_ptr> CreateLegalizeHloToTfPass(); +std::unique_ptr> CreateLegalizeHloToTfPass(); } // namespace TF namespace TFControlFlow { // Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow // dialect. -std::unique_ptr> CreateRaiseTFControlFlowPass(); +std::unique_ptr> CreateRaiseTFControlFlowPass(); } // namespace TFControlFlow @@ -138,29 +139,30 @@ namespace tf_executor { class GraphOp; // Returns a pass that folds switch nodes with constant predicates. -std::unique_ptr> CreateSwitchFoldPass(); +std::unique_ptr> CreateSwitchFoldPass(); // Creates a pass to merge IslandOps from TFExecutor dialect. -std::unique_ptr> CreateTFExecutorIslandCoarseningPass(); +std::unique_ptr> CreateTFExecutorIslandCoarseningPass(); // Creates a pass to merge IslandOps for operation marked for execution on TPU. // This is a V1 backward compatibility. -std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandCoarseningPass(); // Creates a pass to outlining TPU clusters from single IslandOp into a nested // module suitable for being processed as-if it was a V2 module. // This is a V1 backward compatibility. 
-std::unique_ptr> +std::unique_ptr> CreateTFExecutorTPUV1IslandOutliningPass(); // Creates a pass to inline calls to the nested TPU module, this reverses the // effect of the `TFExecutorTPUV1IslandOutlining` pass above. // This is a V1 backward compatibility. -std::unique_ptr> CreateTFExecutorTPUV1IslandInliningPass(); +std::unique_ptr> +CreateTFExecutorTPUV1IslandInliningPass(); // Creates a pass to prune tf_executor.graph from dead nodes. -std::unique_ptr> CreateTFExecutorGraphPruningPass(); +std::unique_ptr> CreateTFExecutorGraphPruningPass(); // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph); @@ -168,29 +170,29 @@ void PruneGraph(GraphOp graph); // Sink `tf.Const` operations in the LaunchOp region using them. This is // performed in order to limit the number of values implicitly captured in this // region before outlining. -std::unique_ptr> CreateTFExecutorConstantSinkingPass(); +std::unique_ptr> CreateTFExecutorConstantSinkingPass(); } // namespace tf_executor namespace TFDevice { // Creates a pass that forms clusters from instructions that are assigned to // same device. -std::unique_ptr> CreateClusterFormationPass(); +std::unique_ptr> CreateClusterFormationPass(); // Creates a pass that outlines regions of tf_device.launch operations. -std::unique_ptr> CreateClusterOutliningPass(); +std::unique_ptr> CreateClusterOutliningPass(); // A pass that decomposes composite resource operations into primitive ones like // ReadVariableOp, AssignVariableOp and other computations to facilitate // transformations like resource op lifting. -std::unique_ptr> CreateDecomposeResourceOpsPass(); +std::unique_ptr> CreateDecomposeResourceOpsPass(); // Creates a pass that lifts operations on external resource variables from // device computation nested in `tf_device::LaunchOp` out so that resource // variable load operations are all before device computation while resource // variable store operations are all after device computation. 
After this pass, // device computation no longer interacts with external resource variables. -std::unique_ptr> CreateResourceOpLiftingPass(); +std::unique_ptr> CreateResourceOpLiftingPass(); // Lifts resource operations from tf_device.launch_func ops nested in `op` // outside. Returns a failure if there are remaining resource-type values that @@ -198,55 +200,56 @@ std::unique_ptr> CreateResourceOpLiftingPass(); LogicalResult LiftResourceOps(Operation* op); // Creates a pass that hoists invariant operations in a `tf_device.replicate`. -std::unique_ptr> CreateReplicateInvariantOpHoistingPass(); +std::unique_ptr> CreateReplicateInvariantOpHoistingPass(); // Creates a pass that forms replica `tf_executor.island` from a single // `tf_device.replicate` island. -std::unique_ptr> CreateReplicateToIslandPass(); +std::unique_ptr> CreateReplicateToIslandPass(); // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. -std::unique_ptr> CreateParallelExecuteToIslandsPass(); +std::unique_ptr> CreateParallelExecuteToIslandsPass(); // Creates a pass that annotates whether a LaunchFuncOp's parameters have the // same data across replicas. -std::unique_ptr> CreateAnnotateParameterReplicationPass(); +std::unique_ptr> +CreateAnnotateParameterReplicationPass(); // Creates a pass that hoists a `tf_device.launch` body and assigns a `device` // attribute to each TensorFlow dialect op in the body based on the `device` // attribute on the `tf_device.launch`. -std::unique_ptr> CreateLaunchToDeviceAttributePass(); +std::unique_ptr> CreateLaunchToDeviceAttributePass(); } // namespace TFDevice namespace TFTPU { // Creates a pass that forms clusters from operations of the same // `_tpu_replicate` attribute. -std::unique_ptr> CreateTPUClusterFormationPass(); +std::unique_ptr> CreateTPUClusterFormationPass(); // Creates a pass that allows TPU program inputs to have layouts determined at // run time. 
-std::unique_ptr> CreateTPUDynamicLayoutPass(); +std::unique_ptr> CreateTPUDynamicLayoutPass(); // Creates a pass that remaps and assigns padding map from a // `tf_device.launch_func` `padding_map` attribute to its encapsulated function. -std::unique_ptr> CreateTPUDynamicPaddingMapperPass(); +std::unique_ptr> CreateTPUDynamicPaddingMapperPass(); // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime // ops. -std::unique_ptr> CreateTPURewritePass(); +std::unique_ptr> CreateTPURewritePass(); // Creates a pass that identifies XLASharding ops in launch op for TPU // computation. -std::unique_ptr> CreateTPUShardingIdentificationPass(); +std::unique_ptr> CreateTPUShardingIdentificationPass(); // Creates a pass that merges device variable reads/updates into the surrounded // TPUExecute node. This allows the execute node to perform in-place variable // updates. -std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); +std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); // Creates a pass that adds ops which perform formatting on variables at // run-time according to compilation result. -std::unique_ptr> CreateTPUVariableReformattingPass(); +std::unique_ptr> CreateTPUVariableReformattingPass(); // Populates the supplied passmanager with the passes required to run the void CreateTPUBridgePipeline(OpPassManager& pm); @@ -260,16 +263,16 @@ void CreateTPUBridgePipelineV1(OpPassManager& pm); namespace tf_saved_model { // Creates a pass that optimizes tf_saved_model.global_tensor ops. -std::unique_ptr> CreateOptimizeGlobalTensorsPass(); +std::unique_ptr> CreateOptimizeGlobalTensorsPass(); // Creates a pass that freezes tf_saved_model.global_tensor ops. -std::unique_ptr> CreateFreezeGlobalTensorsPass(); +std::unique_ptr> CreateFreezeGlobalTensorsPass(); // Creates a pass that uses tf_saved_model dialect linkage information // to mark function visibility. 
That is, exported functions are marked with // public visibility while the other functions are marked with private // visibility. -std::unique_ptr> +std::unique_ptr> CreateMarkFunctionVisibilityUsingSavedModelLinkagePass(); } // namespace tf_saved_model diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index 61644866886..d58fd55da64 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -258,7 +258,7 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { } class PromoteResourcesToArgsPass - : public OperationPass { + : public PassWrapper> { public: void runOnOperation() override; }; @@ -285,7 +285,7 @@ void PromoteResourcesToArgsPass::runOnOperation() { } // namespace -std::unique_ptr> CreatePromoteResourcesToArgsPass() { +std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc index e71b4a530b3..ca234818e10 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc @@ -32,7 +32,8 @@ namespace mlir { namespace TFControlFlow { namespace { -struct RaiseTFControlFlow : public FunctionPass { +struct RaiseTFControlFlow + : public PassWrapper { void runOnFunction() { // First start by recognizing loops and reconstructing a loop tree. 
buildLoopNests(); @@ -145,7 +146,7 @@ void RaiseTFControlFlow::rewriteOps() { } // namespace -std::unique_ptr> CreateRaiseTFControlFlowPass() { +std::unique_ptr> CreateRaiseTFControlFlowPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 5c21e1bffcc..031d57e99ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -37,7 +37,7 @@ namespace { constexpr char kDeviceAttr[] = "device"; struct ReplicateInvariantOpHoistingPass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -178,7 +178,8 @@ void ReplicateInvariantOpHoistingPass::runOnFunction() { } } // anonymous namespace -std::unique_ptr> CreateReplicateInvariantOpHoistingPass() { +std::unique_ptr> +CreateReplicateInvariantOpHoistingPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 0b41225e503..a781f054755 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -43,7 +43,8 @@ namespace TFDevice { namespace { constexpr char kDeviceAttr[] = "device"; -struct ReplicateToIslandPass : public FunctionPass { +struct ReplicateToIslandPass + : public PassWrapper { void runOnFunction() override; }; @@ -237,7 +238,7 @@ void ReplicateToIslandPass::runOnFunction() { } } // anonymous namespace -std::unique_ptr> CreateReplicateToIslandPass() { +std::unique_ptr> CreateReplicateToIslandPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index 2ae62bfee10..d37dfd14590 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -54,7 +54,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device"; // This pass changes the module by adding "tf.device" attribute to function // arguments and adding "device" attribute to TF ops. struct ResourceDeviceInference - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -266,7 +266,7 @@ void ResourceDeviceInference::runOnOperation() { } // namespace -std::unique_ptr> CreateResourceDeviceInferencePass() { +std::unique_ptr> CreateResourceDeviceInferencePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 420367f72b8..43316faea34 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -132,7 +132,7 @@ namespace { // } // struct ResourceOpLiftingPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -1071,7 +1071,8 @@ void ResourceOpLiftingPass::runOnOperation() { } struct ResourceOpLiftingForMainFunctionPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -1100,7 +1101,7 @@ static PassRegistration pass( } // namespace namespace TFDevice { -std::unique_ptr> CreateResourceOpLiftingPass() { +std::unique_ptr> CreateResourceOpLiftingPass() { return std::make_unique(); } } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index c90089ad9d5..48e4e77ce0f 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -47,7 +47,8 @@ namespace { // This transformation pass propagate shapes on the TensorFlow graph. // It is a ModulePass in order to be able to change function types. -struct ShapeInference : public OperationPass { +struct ShapeInference + : public PassWrapper> { void runOnOperation() override { auto module = getOperation(); auto producer_or = tensorflow::GetTfGraphProducerVersion(module); @@ -70,7 +71,7 @@ PassRegistration pass( } // namespace -std::unique_ptr> CreateTFShapeInferencePass() { +std::unique_ptr> CreateTFShapeInferencePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index ff3dae278b3..0eafdea0964 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -39,7 +39,7 @@ namespace { using ::mlir::TF::ConstOp; class ExecutorConstantSinking - : public mlir::FunctionPass { + : public mlir::PassWrapper { void runOnFunction() override { getFunction().walk([](tf_device::LaunchOp launch) { LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n"); @@ -89,7 +89,7 @@ static mlir::PassRegistration pass( } // anonymous namespace -std::unique_ptr> CreateTFExecutorConstantSinkingPass() { +std::unique_ptr> CreateTFExecutorConstantSinkingPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 6abf4893327..55b22ad8625 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -85,7 +85,7 @@ namespace cutil = TF::collection_ops_util; 
// // The pass also works across control flow and functional calls. struct StackOpsDecompositionPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -568,7 +568,7 @@ static PassRegistration pass( } // namespace namespace TF { -std::unique_ptr> CreateStackOpsDecompositionPass() { +std::unique_ptr> CreateStackOpsDecompositionPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index f97d9306a43..eda2e0e1bad 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -68,7 +68,8 @@ using std::string; // shape. // struct TensorArrayOpsDecompositionPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -893,7 +894,8 @@ static PassRegistration pass( } // namespace namespace TF { -std::unique_ptr> CreateTensorArrayOpsDecompositionPass() { +std::unique_ptr> +CreateTensorArrayOpsDecompositionPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index d8ae4fb534a..4eb078b7d2f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -62,7 +62,8 @@ namespace cutil = TF::collection_ops_util; // // The pass also works across control flow and functional calls. 
struct TensorListOpsDecompositionPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -383,7 +384,7 @@ LogicalResult GetConstShapeValue(Value shape_value, if (!shape_op) return failure(); auto shape_const_op = llvm::dyn_cast(shape_op); if (!shape_const_op) return failure(); - for (auto v : shape_const_op.value().getValues()) { + for (const auto& v : shape_const_op.value().getValues()) { shape->push_back(v.getSExtValue()); } return success(); @@ -728,7 +729,8 @@ static PassRegistration pass( } // namespace namespace TF { -std::unique_ptr> CreateTensorListOpsDecompositionPass() { +std::unique_ptr> +CreateTensorListOpsDecompositionPass() { return std::make_unique(); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc index 38960eef411..6b284222526 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -39,7 +39,7 @@ namespace { // A pass that adds "Predecessors" and "Successors" remarks for each op based on // SideEffectAnalysis result. For testing purpose only. 
struct TestSideEffectAnalysis - : public mlir::FunctionPass { + : public mlir::PassWrapper { void runOnFunction() override { int64_t next_id = 0; llvm::SmallDenseMap ids; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc index 83451e130ba..2a770b2615d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -24,7 +24,7 @@ namespace TF { namespace { class SimpleTFDeviceAssignmentPass - : public FunctionPass { + : public PassWrapper { public: SimpleTFDeviceAssignmentPass() = default; SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {} @@ -57,7 +57,7 @@ class SimpleTFDeviceAssignmentPass } // namespace -std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( +std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( llvm::StringRef default_device) { return std::make_unique(default_device); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index cf45c8da5e9..500b879e697 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -40,7 +40,9 @@ namespace tensorflow { // Optimization Passes and convert back to MLIR. // Constraints: This pass expects that all operations in the MLIR module either // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect. 
-class GraphOptPass : public mlir::OperationPass { +class GraphOptPass + : public mlir::PassWrapper> { public: explicit GraphOptPass(std::vector passes) : passes_(std::move(passes)) {} @@ -166,13 +168,13 @@ class GraphOptByNamePass : public GraphOptPass { } // namespace tensorflow -std::unique_ptr> +std::unique_ptr> tensorflow::CreateTensorFlowGraphOptimizationPass( std::vector tf_passes) { return std::make_unique(std::move(tf_passes)); } -std::unique_ptr> +std::unique_ptr> tensorflow::CreateTensorFlowGraphOptimizationPass( const std::vector& pass_names) { return std::make_unique(pass_names); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h index bea23f8face..18ec5320d80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h @@ -24,7 +24,7 @@ namespace tensorflow { // Create a module pass that will execute the given TF GraphOptimization passes // in sequence. // Pass requires that the module ran on is convertible to TF Graph. -std::unique_ptr> +std::unique_ptr> CreateTensorFlowGraphOptimizationPass( std::vector tf_passes); @@ -32,7 +32,7 @@ CreateTensorFlowGraphOptimizationPass( // passes are queried, if a TF graph optimization pass is not found in registry // then the pass fails. // Pass requires that the module ran on is convertible to TF Graph. 
-std::unique_ptr> +std::unique_ptr> CreateTensorFlowGraphOptimizationPass( const std::vector& pass_names); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index fe11fee9f08..860d537c7ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -71,7 +71,8 @@ using MetadataMap = llvm::SmallDenseMap; using ClusterMap = llvm::SmallDenseMap, 8>; -struct TPUClusterFormation : public FunctionPass { +struct TPUClusterFormation + : public PassWrapper { void runOnFunction() override; }; @@ -502,7 +503,7 @@ void TPUClusterFormation::runOnFunction() { } } // anonymous namespace -std::unique_ptr> CreateTPUClusterFormationPass() { +std::unique_ptr> CreateTPUClusterFormationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 45fd3a5751d..6fb686995b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/util/device_name_utils.h" @@ -73,7 +74,8 @@ constexpr char kDeviceAttr[] = "device"; // %copy_to_device. There will not be send/recv ops added by later passes, // because tf.TPUCopyWithLayout accepts a host input and produces a device // output. 
-struct TPUDynamicLayoutPass : public FunctionPass { +struct TPUDynamicLayoutPass + : public PassWrapper { void runOnFunction() override; }; @@ -91,63 +93,59 @@ bool IsSupportedInputOp(Operation* op) { return parsed_device.type == "CPU"; } +OpBuilder CreateBuilderAfterOp(Operation* op) { + return OpBuilder(op->getBlock(), ++Block::iterator(op)); +} + // Builds a TPUGetLayoutOp with the given compile op and input index. -TF::TPUGetLayoutOp BuildGetLayout(tf_device::LaunchOp compile_launch, - int64_t index, OpBuilder* builder) { - builder->setInsertionPointAfter(compile_launch); +TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index, + Value compilation_key, + tf_device::LaunchOp compile_launch, + OpBuilder* builder) { return builder->create( compile_launch.getLoc(), - llvm::ArrayRef{ - RankedTensorType::get({-1}, builder->getIntegerType(64))}, - llvm::ArrayRef{compile_launch.getResult(1)}, + llvm::ArrayRef{RankedTensorType::get({ShapedType::kDynamicSize}, + builder->getIntegerType(64))}, + llvm::ArrayRef{compilation_key}, llvm::ArrayRef{ - builder->getNamedAttr("index", builder->getI64IntegerAttr(index)), + builder->getNamedAttr("index", + builder->getI64IntegerAttr(execute_arg_index)), builder->getNamedAttr("is_output", builder->getBoolAttr(false))}); } // Builds a TPUCopyWithLayoutOp with the given get_layout op and input. -// walk_order for ops in the original IR is needed because we need to insert the -// ops after both get_layout and input, so we use the walk order to find which -// one comes later. 
-TF::TPUCopyWithLayoutOp BuildCopyWithLayout( - TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch, - TF::TPUGetLayoutOp get_layout, Value input, - const llvm::SmallDenseMap& walk_order, - OpBuilder* builder) { - auto input_op = input.getDefiningOp(); - int64_t compile_walk_order = walk_order.find(compile_launch)->getSecond(); - int64_t input_walk_order = walk_order.find(input_op)->getSecond(); - if (compile_walk_order > input_walk_order) { - builder->setInsertionPointAfter(get_layout); - } else { - builder->setInsertionPointAfter(input_op); - } +TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch, + tf_device::LaunchOp compile_launch, + TF::TPUGetLayoutOp get_layout, + Value input, OpBuilder* builder) { return builder->create( - execute.getLoc(), llvm::ArrayRef{input.getType()}, + execute_launch.getLoc(), llvm::ArrayRef{input.getType()}, llvm::ArrayRef{input, get_layout.layout()}, llvm::ArrayRef{}); } // Performs transformation for a non-replicated input. 
-void HandleInput(Value input, int64_t index, TF::TPUExecuteOp execute, - tf_device::LaunchOp compile_launch, - const llvm::SmallDenseMap& walk_order) { - OpBuilder builder(compile_launch.getContext()); - auto get_layout = BuildGetLayout(compile_launch, index, &builder); - auto copy_with_layout = BuildCopyWithLayout( - execute, compile_launch, get_layout, input, walk_order, &builder); - copy_with_layout.setAttr( - kDeviceAttr, - llvm::cast(execute.getParentOp()).deviceAttr()); - execute.setOperand(index, copy_with_layout); +void HandleInput(Value input, const int64_t execute_arg_index, + TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch, + tf_device::LaunchOp compile_launch) { + OpBuilder builder = CreateBuilderAfterOp(compile_launch); + auto get_layout = BuildGetLayout(execute_arg_index, execute.key(), + compile_launch, &builder); + builder.setInsertionPoint(execute_launch); + auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch, + get_layout, input, &builder); + copy_with_layout.setAttr(kDeviceAttr, execute_launch.deviceAttr()); + execute.setOperand(execute_arg_index, copy_with_layout); } // Performs transformation for replicated inputs. Returns true if this is a // supported case (thus transform happened). -bool HandleReplicatedInputs( - int64_t index, TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch, - int64_t replicate_arg_index, tf_device::ReplicateOp replicate, - const llvm::SmallDenseMap& walk_order) { +bool HandleReplicatedInputs(const int64_t execute_arg_index, + Value compilation_key, + tf_device::LaunchOp execute_launch, + tf_device::LaunchOp compile_launch, + const int64_t replicate_arg_index, + tf_device::ReplicateOp replicate) { // We need to know the devices to copy to. 
if (!replicate.devices()) return false; int64_t num_replicas = replicate.n().getZExtValue(); @@ -158,18 +156,17 @@ bool HandleReplicatedInputs( auto input_op = entry.value().getDefiningOp(); if (!input_op || !IsSupportedInputOp(input_op)) return false; } - OpBuilder builder(execute.getContext()); - auto get_layout = BuildGetLayout(compile_launch, index, &builder); + OpBuilder builder = CreateBuilderAfterOp(compile_launch); + auto get_layout = BuildGetLayout(execute_arg_index, compilation_key, + compile_launch, &builder); + builder.setInsertionPoint(replicate); for (auto entry : llvm::enumerate(inputs)) { - auto copy_with_layout = - BuildCopyWithLayout(execute, compile_launch, get_layout, entry.value(), - walk_order, &builder); + auto copy_with_layout = BuildCopyWithLayout( + execute_launch, compile_launch, get_layout, entry.value(), &builder); - // As model parallelism is not supported yet, assume that all ops are - // placed at logical core 0. auto device_list = replicate.devices() .getValue() - .get(tensorflow::GetDeviceAliasForLogicalCore(0)) + .get(execute_launch.getDevice()) .cast(); copy_with_layout.setAttr(kDeviceAttr, device_list.getValue()[entry.index()]); @@ -180,83 +177,94 @@ bool HandleReplicatedInputs( return true; } -// Performs transformation on a pair of execute and compile ops. The compile -// should not have other uses. -void HandleExecute(TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch, - const llvm::SmallDenseMap& walk_order) { - auto maybe_replicate = execute.getParentOfType(); - llvm::SmallVector unrestricted_input_indices; - for (auto input : llvm::enumerate(execute.args())) { - if (auto block_arg = input.value().dyn_cast()) { - // For a block argument, consider transforms only when it is a replicated - // input (defining ops will be outside the replicate node). 
- if (maybe_replicate != block_arg.getParentRegion()->getParentOp() || - !HandleReplicatedInputs(input.index(), execute, compile_launch, - block_arg.getArgNumber(), maybe_replicate, - walk_order)) { - continue; - } - } else { - // For an op output, consider transforms only when 1) there is no - // replicateion or 2) it is outside the replicate node that encloses the - // execute node. (Because if the op is inside replicate, it is probably - // not on the host.) - auto input_op = input.value().getDefiningOp(); - if (maybe_replicate && - maybe_replicate.body().isAncestor(input_op->getParentRegion())) { - continue; - } - if (!IsSupportedInputOp(input_op)) continue; - HandleInput(input.value(), input.index(), execute, compile_launch, - walk_order); - } - unrestricted_input_indices.push_back(input.index()); - } - if (unrestricted_input_indices.empty()) return; - - // Update the compilation metadata if we changed anything. - Operation& compile = compile_launch.GetBody().front(); - auto metadata_attr = compile.getAttrOfType("metadata"); - assert(metadata_attr && "Missing compilation metadata"); +// Performs transformation on a compile and associated execute(s) ops. The +// compile should not have other uses. 
+void HandleCompileAndExecutes( + tf_device::LaunchOp compile_launch, + llvm::MutableArrayRef execute_launches) { + auto compile = + llvm::cast(compile_launch.GetBody().front()); tensorflow::tpu::TPUCompileMetadataProto metadata; - metadata.ParseFromString(std::string(metadata_attr.getValue())); - for (int64_t input_index : unrestricted_input_indices) { - metadata.mutable_args(input_index)->set_unrestricted_layout(true); + metadata.ParseFromString(compile.metadata().str()); + llvm::SmallVector, 4> input_mappings = + tensorflow::GetMetadataArgumentMapping(metadata); + + bool metadata_updated = false; + auto maybe_replicate = + execute_launches.front().getParentOfType(); + + for (auto execute_and_input_mapping : + llvm::zip(execute_launches, input_mappings)) { + auto& execute_launch = std::get<0>(execute_and_input_mapping); + auto execute = + llvm::cast(execute_launch.GetBody().front()); + const auto& input_mapping = std::get<1>(execute_and_input_mapping); + + for (auto& input_and_idx : llvm::enumerate(execute.args())) { + Value input = input_and_idx.value(); + const int64_t execute_arg_index = input_and_idx.index(); + if (auto block_arg = input.dyn_cast()) { + // For a block argument, consider transforms only when it is a + // replicated input (defining ops will be outside the replicate node). + if (maybe_replicate != block_arg.getParentRegion()->getParentOp() || + !HandleReplicatedInputs( + execute_arg_index, execute.key(), execute_launch, + compile_launch, block_arg.getArgNumber(), maybe_replicate)) { + continue; + } + } else { + // For an op output, consider transforms only when 1) there is no + // replication or 2) it is outside the replicate node that encloses the + // execute node. (Because if the op is inside replicate, it is probably + // not on the host.) 
+ auto* input_op = input.getDefiningOp(); + if (maybe_replicate && + maybe_replicate.body().isAncestor(input_op->getParentRegion())) { + continue; + } + if (!IsSupportedInputOp(input_op)) continue; + HandleInput(input, execute_arg_index, execute, execute_launch, + compile_launch); + } + + metadata.mutable_args(input_mapping[execute_arg_index]) + ->set_unrestricted_layout(true); + metadata_updated = true; + } } - compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(), - compile.getContext())); + + if (metadata_updated) + compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(), + compile.getContext())); } void TPUDynamicLayoutPass::runOnFunction() { - llvm::SmallVector, 4> - executes_and_compiles; - llvm::SmallDenseMap walk_order; - int64_t next_walk_order = 0; - getFunction().walk([&](Operation* op) { - walk_order[op] = next_walk_order++; - // Detect tf._TPUCompileMlir -> tf.TPUExecute - auto execute = llvm::dyn_cast(op); - if (!execute) return; - auto execute_launch = - llvm::dyn_cast_or_null(execute.getParentOp()); - if (!execute_launch || !execute_launch.WrapsSingleOp()) return; - auto compile = execute.key().getDefiningOp(); - if (!compile || !compile->getResult(1).hasOneUse()) return; - auto compile_launch = llvm::dyn_cast(compile); - if (!compile_launch || !compile_launch.WrapsSingleOp() || - !llvm::isa(compile_launch.GetBody().front())) - return; - executes_and_compiles.emplace_back(execute, compile_launch); + getFunction().walk([](TF::_TPUCompileMlirOp compile) { + // Detect tf._TPUCompileMlir -> tf.TPUExecute(s). 
+ auto compile_launch = + llvm::dyn_cast(compile.getParentOp()); + if (!compile_launch || !compile_launch.WrapsSingleOp()) return; + + llvm::SmallVector execute_launches; + execute_launches.reserve(compile_launch.getNumResults() - 1); + for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) { + if (!program_result.hasOneUse()) return; + Operation* user = *program_result.user_begin(); + auto execute = llvm::dyn_cast(user); + if (!execute) return; + auto execute_launch = + llvm::dyn_cast(execute.getParentOp()); + if (!execute_launch || !execute_launch.WrapsSingleOp()) return; + execute_launches.push_back(execute_launch); + } + + HandleCompileAndExecutes(compile_launch, execute_launches); }); - for (auto execute_and_compile : executes_and_compiles) { - HandleExecute(execute_and_compile.first, execute_and_compile.second, - walk_order); - } } } // namespace -std::unique_ptr> CreateTPUDynamicLayoutPass() { +std::unique_ptr> CreateTPUDynamicLayoutPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index a54826c8f8e..df2f1b3b326 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -49,7 +49,7 @@ constexpr char kPaddingMapAttr[] = "padding_map"; namespace { struct TPUDynamicPaddingMapper - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -200,7 +200,7 @@ void TPUDynamicPaddingMapper::runOnOperation() { } } // anonymous namespace -std::unique_ptr> CreateTPUDynamicPaddingMapperPass() { +std::unique_ptr> CreateTPUDynamicPaddingMapperPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index 46d22844457..3fd0dcd5a67 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -75,7 +75,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device"; // the TPUExecute op. struct TPUMergeVariablesWithExecutePass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; @@ -531,7 +531,8 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() { } // namespace -std::unique_ptr> CreateTPUMergeVariablesWithExecutePass() { +std::unique_ptr> +CreateTPUMergeVariablesWithExecutePass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index e735fa918bb..a635fdb9a1f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -98,7 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] = // %4 = "tf.SomeOp"(%3) namespace { -struct TPURewritePass : public OperationPass { +struct TPURewritePass + : public PassWrapper> { void runOnOperation() override; }; @@ -223,7 +224,7 @@ LogicalResult SetMetadataProtoPaddingMap( if (!padding_map) return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr)); - for (const auto padding_and_idx : llvm::enumerate(padding_map)) { + for (const auto& padding_and_idx : llvm::enumerate(padding_map)) { auto& padding_attr = padding_and_idx.value(); auto padding_attr_str = padding_attr.dyn_cast(); if (!padding_attr_str) @@ -770,7 +771,7 @@ void TPURewritePass::runOnOperation() { } // namespace -std::unique_ptr> CreateTPURewritePass() { +std::unique_ptr> CreateTPURewritePass() { return std::make_unique(); } diff --git 
a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 05c8e096f38..f0455cf010a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -40,7 +40,8 @@ namespace { constexpr char kShardingAttr[] = "xla_hlo.sharding"; struct TPUShardingIdentificationPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -185,7 +186,7 @@ void TPUShardingIdentificationPass::runOnOperation() { } // anonymous namespace -std::unique_ptr> CreateTPUShardingIdentificationPass() { +std::unique_ptr> CreateTPUShardingIdentificationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index a58c28c50d1..a6ea26b1ebf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -116,7 +116,8 @@ std::string GetRandomStateVariableName() { // tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate) // } struct TPUVariableRuntimeReformattingPass - : public OperationPass { + : public PassWrapper> { void runOnOperation() override; }; @@ -575,7 +576,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() { } // namespace -std::unique_ptr> CreateTPUVariableReformattingPass() { +std::unique_ptr> CreateTPUVariableReformattingPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index d5603416d54..1856f9541f4 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -44,7 +44,8 @@ namespace { // Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out // of the inputs, matmul them individually, then stack them all back together at // the end. -struct UnrollBatchMatMulPass : public FunctionPass { +struct UnrollBatchMatMulPass + : public PassWrapper { void runOnFunction() override; }; @@ -309,7 +310,7 @@ static PassRegistration pass( "tf-unroll-batch-matmul", "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops."); -std::unique_ptr> CreateUnrollBatchMatMulPassPass() { +std::unique_ptr> CreateUnrollBatchMatMulPassPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index d33dfba50ea..510337b54cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -42,7 +42,7 @@ namespace mlir { namespace { -struct BreakUpIslands : FunctionPass { +struct BreakUpIslands : PassWrapper { void runOnFunction() final; void BreakUpIsland(tf_executor::IslandOp island_op, @@ -325,7 +325,7 @@ void BreakUpIslands::BreakUpIsland( } // namespace -std::unique_ptr> CreateBreakUpIslandsPass() { +std::unique_ptr> CreateBreakUpIslandsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index 9d66fc9d355..b5ebd45936a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -45,7 +45,7 @@ namespace { // otherwise _tf operations are wrapped in an island and the _ prefix is // removed. 
Control dependencies are moved to be handled by the island itself. struct ControlToExecutorDialectConversion - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; private: @@ -237,7 +237,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { } } -OpPassBase *CreateTFControlToExecutorDialectConversion() { +OperationPass *CreateTFControlToExecutorDialectConversion() { return new ControlToExecutorDialectConversion(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 40a359808cf..7d0b75006a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -39,7 +39,7 @@ namespace mlir { namespace { struct ExecutorToControlDialectConversion - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; } // end anonymous namespace @@ -230,7 +230,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { graph.erase(); } -std::unique_ptr> +std::unique_ptr> CreateTFExecutorToControlDialectConversion() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index e2b52da0f68..851eb03edac 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2004,7 +2004,7 @@ StatusOr GraphDefImporter::InferMainFunctionType( llvm::SmallVector arg_types; arg_types.reserve(specs.inputs.size()); int i = 0; - for (auto it : specs.inputs) { + for (const auto& it : specs.inputs) { Node* arg_node = arg_nodes->at(i).node; if (arg_node == nullptr) { return errors::InvalidArgument("Input ", it.first, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc 
b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index b2cf906be0d..77da19d6853 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -81,7 +81,7 @@ Status ParseInputArrayInfo(const std::vector& node_names, // using the type from the graph. used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID)); } else if (node_names.size() == node_dtypes.size()) { - for (auto dtype : node_dtypes) { + for (const auto& dtype : node_dtypes) { if (dtype.empty()) { used_node_dtypes.push_back(DataType_Name(DT_INVALID)); } else if (dtype != DataType_Name(DT_INVALID)) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index 4a625b62857..29f98de6448 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -40,7 +40,7 @@ namespace { // return %graph_results#... 
// } struct FunctionalToExecutorDialectConversion - : public FunctionPass { + : public PassWrapper { void runOnFunction() override; }; } // end anonymous namespace @@ -95,7 +95,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() { } } -std::unique_ptr> +std::unique_ptr> CreateFunctionalToExecutorDialectConversionPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 3e250ec287b..7a627780f25 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -254,8 +254,9 @@ static void RegisterDialects() { } // namespace Status ConvertMLIRToXlaComputation( - mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple, + mlir::ModuleOp module_op, llvm::StringRef device_type, + xla::XlaComputation* xla_computation, bool use_tuple_args, + bool return_tuple, const XlaCompiler::ShapeRepresentationFn shape_representation_fn) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); @@ -268,6 +269,7 @@ Status ConvertMLIRToXlaComputation( // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. 
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass()); + tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); // We need to run LegalizeTFPass 2 times because first // LegalizeTFPass(allow_partial_conversion=true) can expose more graph pruning // and canonicalization opportunities that are necessary for the second @@ -308,7 +310,7 @@ Status ConvertMLIRToXlaComputation( static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { if (VLOG_IS_ON(1)) @@ -326,7 +328,8 @@ static Status CompileMlirToXlaHlo( // Convert MLIR module to XLA HLO proto contained in XlaComputation. compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( - module_op, compilation_result->computation.get(), use_tuple_args, + module_op, device_type, compilation_result->computation.get(), + use_tuple_args, /*return_tuple=*/true, shape_representation_fn)); // Construct mapping from XlaComputation's arg to input edges of execute @@ -355,7 +358,7 @@ static Status CompileMlirToXlaHlo( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { RegisterDialects(); @@ -364,14 +367,15 @@ Status CompileSerializedMlirToXlaHlo( TF_RETURN_IF_ERROR( ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, - shape_representation_fn, compilation_result); + return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type, + use_tuple_args, shape_representation_fn, + compilation_result); } Status 
CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, + llvm::StringRef device_type, bool use_tuple_args, + const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { RegisterDialects(); @@ -383,8 +387,8 @@ Status CompileGraphToXlaHlo( if (!module_or.ok()) return module_or.status(); return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - use_tuple_args, shape_representation_fn, - compilation_result); + device_type, use_tuple_args, + shape_representation_fn, compilation_result); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 2ce0a31eb78..74c602a7afb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -29,6 +29,8 @@ namespace tensorflow { // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // should only contain operations in tf dialect. If the input module contains // operation in the tf_executor dialect, for example, returns an error. +// Exception to this are tf_executor dialect ops that are optimized away through +// canonicalization. // // Operations in tf dialect are lowered to XLA HLO through the following steps: // . Legalizes control flow operations. @@ -39,6 +41,8 @@ namespace tensorflow { // . Legalizes the operations to XLA HLO operations. // . Canonicalizes the XLA HLO operations. // +// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT", +// "XLA_GPU_JIT" or "XLA_TPU_JIT". // use_tuple_args: when this is true, always create a tuple argument for the // entry computation. 
// return_tuple: when this is true, always create a tuple result for the @@ -47,23 +51,24 @@ namespace tensorflow { // will be used to determine argument and result shapes. Otherwise the // original shape will be used as is. Status ConvertMLIRToXlaComputation( - mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple, + mlir::ModuleOp module_op, llvm::StringRef device_type, + xla::XlaComputation* xla_computation, bool use_tuple_args, + bool return_tuple, const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); // Same as the above but takes input as TensorFlow Graph. 
Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, + llvm::StringRef device_type, bool use_tuple_args, + const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index d406934c520..26c50a24f58 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -46,7 +46,7 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - invalid_mlir_module, arg_shapes, + invalid_mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); EXPECT_EQ(s.ToString(), @@ -68,7 +68,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, + kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -126,7 +126,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, + kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -197,7 +197,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { XlaCompiler::CompilationResult 
compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -236,7 +236,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -267,7 +267,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -325,7 +325,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -362,7 +362,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); ASSERT_FALSE(s.ok()); EXPECT_EQ(s.error_message(), @@ -384,7 +384,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -424,9 +424,10 @@ TEST(CompileGraphToXlaHlo, Basic) { 
test::graph::Retval(&graph, 0, arg); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(CompileGraphToXlaHlo( - graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, - GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); + TF_ASSERT_OK( + CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT", + /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), + /*shape_representation_fn=*/nullptr, &result)); const xla::HloModuleConfig module_config( result.computation->GetProgramShape().ValueOrDie()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index bdb4ebc5058..29de158ff3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -216,7 +216,7 @@ Status ConvertHalfElementsAttr(const ElementsAttr attr, output_tensor->add_half_val( (*elts.begin()).bitcastToAPInt().getSExtValue()); } else { - for (auto value : elts.getFloatValues()) + for (const auto& value : elts.getFloatValues()) output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); } return Status::OK(); @@ -232,7 +232,8 @@ Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, if (elts.isSplat()) { output_tensor->add_int_val((*elts.begin()).getSExtValue()); } else { - for (auto val : elts) output_tensor->add_int_val(val.getSExtValue()); + for (const auto& val : elts) + output_tensor->add_int_val(val.getSExtValue()); } return Status::OK(); } @@ -269,7 +270,8 @@ Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, if (elts.isSplat()) { output_tensor->add_int64_val((*elts.begin()).getSExtValue()); } else { - for (auto val : elts) output_tensor->add_int64_val(val.getSExtValue()); + for (const auto& val : elts) + output_tensor->add_int64_val(val.getSExtValue()); } return Status::OK(); } @@ -281,7 +283,7 @@ Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, 
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, TensorProto* output_tensor) { if (auto elts = attr.dyn_cast()) { - for (auto val : elts) { + for (const auto& val : elts) { output_tensor->add_bool_val(val.getBoolValue()); } return Status::OK(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index aaff33bce3f..1853183c3b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -242,8 +244,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( sharding.ParseFromString( sharding_attr.cast().getValue().str()); - const auto input_sharing_type = sharding.type(); - if (input_sharing_type == xla::OpSharding::OTHER) { + const auto input_sharding_type = sharding.type(); + if (input_sharding_type == xla::OpSharding::OTHER) { llvm::SmallVector tiled_inputs; auto result = HandleTileShardedInputs( launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs); @@ -260,10 +262,10 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( const int assigned_logical_device = sharding.tile_assignment_devices(i); (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]); } - } else if (input_sharing_type == xla::OpSharding::REPLICATED) { + } else if (input_sharding_type == xla::OpSharding::REPLICATED) { for (auto& inputs : *input_list) inputs.emplace_back(input_value); } else { - assert(input_sharing_type == xla::OpSharding::MAXIMAL); + assert(input_sharding_type == xla::OpSharding::MAXIMAL); const int logical_device_id = sharding.tile_assignment_devices(0); (*input_list)[logical_device_id].emplace_back(input_value); } @@ -514,4 +516,34 @@ void 
RemapOutputsFromLogicalDevices( } } +llvm::SmallVector, 4> GetMetadataArgumentMapping( + const tpu::TPUCompileMetadataProto& metadata) { + llvm::SmallVector, 4> input_mappings( + metadata.num_cores_per_replica(), llvm::SmallVector()); + + if (metadata.num_cores_per_replica() == 1) { + input_mappings.front().resize(metadata.args_size()); + std::iota(input_mappings.front().begin(), input_mappings.front().end(), 0); + return input_mappings; + } + + for (const auto& arg_and_idx : llvm::enumerate(metadata.args())) { + const auto& sharding = arg_and_idx.value().sharding(); + const int64_t idx = arg_and_idx.index(); + + const auto& sharding_type = sharding.type(); + if (sharding_type == xla::OpSharding::OTHER) { + for (const auto& device : sharding.tile_assignment_devices()) + input_mappings[device].push_back(idx); + } else if (sharding_type == xla::OpSharding::REPLICATED) { + for (auto& input : input_mappings) input.push_back(idx); + } else { + assert(sharding_type == xla::OpSharding::MAXIMAL); + input_mappings[sharding.tile_assignment_devices(0)].push_back(idx); + } + } + + return input_mappings; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 2320bd44815..77bfd259cf6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" namespace tensorflow { @@ -68,6 +69,12 @@ void RemapOutputsFromLogicalDevices( mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder); +// Determines each logical core argument to metadata argument index mapping, +// based on sharding. 
The return value is indexed first by logical core then by +// argument index. +llvm::SmallVector, 4> GetMetadataArgumentMapping( + const tpu::TPUCompileMetadataProto& metadata); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 0feb633948d..e20f8543e61 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -160,8 +160,6 @@ cc_library( deps = [ ":hlo", ":mlir_hlo_builder", - "//tensorflow/compiler/jit:xla_cpu_device", - "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_type", diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 7526248baca..dfed190ba1e 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -123,6 +123,8 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, return builder.getI1Type(); case PrimitiveType::F16: return builder.getF16Type(); + case PrimitiveType::BF16: + return builder.getBF16Type(); case PrimitiveType::F32: return builder.getF32Type(); case PrimitiveType::F64: @@ -137,6 +139,8 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, return builder.getIntegerType(64); case PrimitiveType::C64: return mlir::ComplexType::get(builder.getF32Type()); + case PrimitiveType::C128: + return mlir::ComplexType::get(builder.getF64Type()); // TODO(b/130356985): Support unsigned primitive types. 
default: return tensorflow::errors::Internal( diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 86e865a1657..a60ebd76d0e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -104,6 +104,53 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, return GetI64ElementsAttr(slice_limits, builder); } +// Returns the padding value of the given position. If padding_attr is a +// nullptr, returns 0. +static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr, + ArrayRef index) { + if (!padding_attr) return 0; + return padding_attr.getValue(index); +} + +static bool IsOnlyPaddingSpatialDims(Value lhs, + ConvDimensionNumbers dimension_numbers, + DenseIntElementsAttr edge_padding_low, + DenseIntElementsAttr edge_padding_high) { + const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt(); + const int64_t feature_dim = + dimension_numbers.input_feature_dimension().getInt(); + if (edge_padding_low.getValue(batch_dim) || + edge_padding_high.getValue(batch_dim)) + return false; + if (edge_padding_low.getValue(feature_dim) || + edge_padding_high.getValue(feature_dim)) + return false; + return true; +} + +DenseIntElementsAttr BuildConvPaddingAttrs( + DenseIntElementsAttr edge_padding_low, + DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr, + ConvDimensionNumbers dimension_numbers, Builder* builder) { + SmallVector padding_low, padding_high; + for (const auto& dim : dimension_numbers.input_spatial_dimensions()) { + unsigned i = dim.getZExtValue(); + padding_low.push_back(edge_padding_low.getValue(i)); + padding_high.push_back(edge_padding_high.getValue(i)); + } + + int rank = padding_low.size(); + SmallVector padding; + for (unsigned i = 0; i < rank; ++i) { + padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]); + padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + 
padding_high[i]); + } + // padding_attr.getType() doesn't work because it is an optional attribute, + // which can be a nullptr. + auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(type, padding); +} + #include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc" } // namespace @@ -1608,5 +1655,14 @@ LogicalResult deriveShapeFromFirstOperand( return success(); } +//===----------------------------------------------------------------------===// +// ConvOp +//===----------------------------------------------------------------------===// + +void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 00b43198c55..abfc42b20d9 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -144,11 +144,11 @@ def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_ClzOp; -def HLO_CosOp: HLO_UnaryElementwiseOp<"cos", +def HLO_CosOp: HLO_UnaryElementwiseOp<"cosine", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_CosOp; -def HLO_ExpOp: HLO_UnaryElementwiseOp<"exp", +def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_ExpOp; @@ -178,7 +178,7 @@ def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, BASE_HLO_NotOp; -def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", +def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, BASE_HLO_NegOp; @@ -186,7 +186,7 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, 
SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; -def HLO_RoundOp: HLO_UnaryElementwiseOp<"round", +def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", @@ -197,7 +197,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, BASE_HLO_SignOp; -def HLO_SinOp: HLO_UnaryElementwiseOp<"sin", +def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_SinOp; @@ -300,7 +300,7 @@ def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; -def HLO_PowOp : HLO_BinaryElementwiseOp<"pow", +def HLO_PowOp : HLO_BinaryElementwiseOp<"power", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", @@ -872,7 +872,7 @@ def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ let description = "Structure of dimension information for conv op"; } -def HLO_ConvOp : HLO_Op<"conv", [NoSideEffect]>, BASE_HLO_ConvOp { +def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -892,6 +892,7 @@ def HLO_ConvOp : HLO_Op<"conv", [NoSideEffect]>, BASE_HLO_ConvOp { let results = (outs HLO_Tensor); + let hasCanonicalizer = 1; } def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 8dee4d0eb69..f3de67a08c1 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -37,6 +37,11 @@ def HLO_Token : Type()">, "token">; // Any integer 
tensor types def HLO_IntTensor : TensorOf<[HLO_Int]>; +// Any integer tensor type with rank 0 (i.e. representing a single integer). +def HLO_ScalarIntTensor : ShapedContainerType< + [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>, + "a 0-dim integer tensor">; + // Any floating-point tensor types def HLO_FpTensor : TensorOf<[AnyFloat]>; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 92e084bf6a2..7613f1e0ffc 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -53,19 +53,20 @@ def LHLO_BufferOrTuple : AnyTypeOf<[LHLO_Buffer, LHLO_TupleBuffer]>; // XLA nullary op definitions. //===----------------------------------------------------------------------===// -class LHLO_Op traits> : Op; +class LHLO_Op traits> : + Op], traits)>; def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { let arguments = (ins ElementsAttr:$value, - LHLO_Buffer:$output + Arg:$output ); } def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { let arguments = (ins I64Attr:$iota_dimension, - LHLO_Buffer:$output); + Arg:$output); } //===----------------------------------------------------------------------===// @@ -75,8 +76,8 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { class LHLO_UnaryElementwiseOp : LHLO_Op { - let arguments = (ins LHLO_Buffer:$input, - LHLO_Buffer:$output); + let arguments = (ins Arg:$input, + Arg:$output); } def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; @@ -84,17 +85,17 @@ def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp; def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>, BASE_HLO_ConvertOp { - let arguments = (ins LHLO_Buffer:$input, - LHLO_Buffer:$output); + let arguments = (ins Arg:$input, + Arg:$output); } -def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cos">, BASE_HLO_CosOp; +def LHLO_CosOp: 
LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp; -def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exp">, BASE_HLO_ExpOp; +def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp; def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp; -def LHLO_NegOp: LHLO_UnaryElementwiseOp<"neg">, BASE_HLO_NegOp; +def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; @@ -111,9 +112,9 @@ def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; class LHLO_BinaryElementwiseOp traits> : LHLO_Op { let arguments = (ins - LHLO_Buffer:$lhs, - LHLO_Buffer:$rhs, - LHLO_Buffer:$out, + Arg:$lhs, + Arg:$rhs, + Arg:$out, OptionalAttr:$broadcast_dimensions ); } @@ -147,9 +148,9 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [ SingleBlockImplicitTerminator<"TerminatorOp"> ]>, BASE_HLO_ReduceOp { let arguments = (ins - Variadic:$operands, - Variadic:$init_values, - Variadic:$out, + Arg, "", [MemRead]>:$operands, + Arg, "", [MemRead]>:$init_values, + Arg, "", [MemWrite]>:$out, I64ElementsAttr:$dimensions ); @@ -157,14 +158,13 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [ } def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ - NoSideEffect, SingleBlockImplicitTerminator<"TerminatorOp"> ]>, BASE_HLO_ReduceWindowOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$init_value, - LHLO_Buffer:$out, + Arg:$operand, + Arg:$init_value, + Arg:$out, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the input dimensions. 
Similarly, padding values are zero @@ -184,23 +184,23 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ def LHLO_GetTupleElementOp: LHLO_Op<"get_tuple_element", []>, BASE_HLO_GetTupleElementOp { let arguments = (ins - LHLO_TupleBuffer:$input, - LHLO_BufferOrTuple:$out, + Arg:$input, + Arg:$out, I32Attr:$index ); } def LHLO_TupleOp : LHLO_Op<"tuple", []>, BASE_HLO_TupleOp { let arguments = (ins - Variadic:$val, - LHLO_TupleBuffer:$out); + Arg, "", [MemRead]>:$val, + Arg:$out); } def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { let arguments = (ins - LHLO_Buffer:$lhs, - LHLO_Buffer:$rhs, - LHLO_PredBuffer:$out, + Arg:$lhs, + Arg:$rhs, + Arg:$out, OptionalAttr:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); @@ -214,8 +214,8 @@ def LHLO_SliceOp: LHLO_Op< "slice", [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$output, + Arg:$operand, + Arg:$output, I64ElementsAttr:$start_indices, I64ElementsAttr:$limit_indices, I64ElementsAttr:$strides @@ -224,10 +224,10 @@ def LHLO_SliceOp: LHLO_Op< def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$update, - LHLO_Buffer:$output, - Variadic:$start_indices + Arg:$operand, + Arg:$update, + Arg:$output, + Arg, "", [MemRead]>:$start_indices ); } @@ -239,12 +239,12 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, BASE_HLO_BatchNormInferenceOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$scale, - LHLO_Buffer:$offset, - LHLO_Buffer:$mean, - LHLO_Buffer:$variance, - LHLO_Buffer:$output, + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$variance, + Arg:$output, F32Attr:$epsilon, I64Attr:$feature_index ); @@ -253,8 +253,8 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, def LHLO_BroadcastOp : LHLO_Op<"broadcast", []>, BASE_HLO_BroadcastOp { let arguments = (ins - 
LHLO_Buffer:$operand, - LHLO_Buffer:$output, + Arg:$operand, + Arg:$output, I64ElementsAttr:$broadcast_sizes ); } @@ -262,90 +262,90 @@ def LHLO_BroadcastOp : LHLO_Op<"broadcast", def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", []>, BASE_HLO_BroadcastInDimOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$output, + Arg:$operand, + Arg:$output, BroadcastDimAttr:$broadcast_dimensions ); } def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { let arguments = (ins - LHLO_Buffer:$min, - LHLO_Buffer:$operand, - LHLO_Buffer:$max, - LHLO_Buffer:$output + Arg:$min, + Arg:$operand, + Arg:$max, + Arg:$output ); } def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { let arguments = (ins - Variadic:$val, - LHLO_Buffer:$output, + Arg, "", [MemRead]>:$val, + Arg:$output, I64Attr:$dimension ); } -def LHLO_ConvOp : LHLO_Op<"conv", []>, BASE_HLO_ConvOp { +def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { let arguments = (ins - LHLO_Buffer:$lhs, - LHLO_Buffer:$rhs, - LHLO_Buffer:$output + Arg:$lhs, + Arg:$rhs, + Arg:$output ); } def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$output + Arg:$operand, + Arg:$output ); } def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { let arguments = (ins - LHLO_Buffer:$lhs, - LHLO_Buffer:$rhs, + Arg:$lhs, + Arg:$rhs, HLO_PrecisionConfigAttr:$precision_config, - LHLO_Buffer:$output + Arg:$output ); } def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_IntBuffer:$start_indices, + Arg:$operand, + Arg:$start_indices, I64Attr:$index_vector_dim, I64ElementsAttr:$offset_dims, I64ElementsAttr:$slice_sizes, I64ElementsAttr:$collapsed_slice_dims, I64ElementsAttr:$start_index_map, - LHLO_Buffer:$output + Arg:$output ); } def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$output + Arg:$operand, 
+ Arg:$output ); } def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { let arguments = (ins - LHLO_PredBuffer:$pred, - LHLO_Buffer:$on_true, - LHLO_Buffer:$on_false, - LHLO_Buffer:$output + Arg:$pred, + Arg:$on_true, + Arg:$on_false, + Arg:$output ); } -def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", - [NoSideEffect]>, BASE_HLO_SelectAndScatterOp { +def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>, + BASE_HLO_SelectAndScatterOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$source, - LHLO_Buffer:$init_value, - LHLO_Buffer:$out, + Arg:$operand, + Arg:$source, + Arg:$init_value, + Arg:$out, OptionalAttr:$window_dimensions, OptionalAttr:$window_strides, OptionalAttr:$padding @@ -356,28 +356,28 @@ def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { let arguments = (ins - LHLO_Buffer:$operand, + Arg:$operand, I64ElementsAttr:$dimensions, - LHLO_Buffer:$output + Arg:$output ); } def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { let arguments = (ins - LHLO_Buffer:$operand, - LHLO_Buffer:$padding_value, + Arg:$operand, + Arg:$padding_value, I64ElementsAttr:$edge_padding_low, I64ElementsAttr:$edge_padding_high, I64ElementsAttr:$interior_padding, - LHLO_Buffer:$output + Arg:$output ); } def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { let arguments = (ins - LHLO_Buffer:$operand, + Arg:$operand, I64ElementsAttr:$permutation, - LHLO_Buffer:$output + Arg:$output ); } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 8922cc131c6..6d87dc8e603 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -261,18 +261,18 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers( dot_dimension_numbers_attr.lhs_batching_dimensions() .cast(); - for (auto val : rhs_contracting_dimensions) { + for (const auto& val 
: rhs_contracting_dimensions) { dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue()); } - for (auto val : lhs_contracting_dimensions) { + for (const auto& val : lhs_contracting_dimensions) { dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue()); } - for (auto val : rhs_batch_dimensions) { + for (const auto& val : rhs_batch_dimensions) { dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue()); } - for (auto val : lhs_batch_dimensions) { + for (const auto& val : lhs_batch_dimensions) { dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue()); } diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir index 2a1975384e5..7bcf477d45e 100644 --- a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir +++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir @@ -7,7 +7,7 @@ func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ ^bb1: br ^exit(%arg0 : tensor<2xf32>) ^bb2: - %1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> br ^exit(%1 : tensor<2xf32>) ^exit(%arg1: tensor<2xf32>): return %arg1 : tensor<2xf32> @@ -21,7 +21,7 @@ func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ // CHECK: Alloc: cond_br cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) ^bb1: - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> br ^exit(%0 : tensor<2xf32>) ^exit(%arg1: tensor<2xf32>): return %arg1 : tensor<2xf32> @@ -32,8 +32,8 @@ func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ // CHECK-LABEL: Testing : invCriticalEdge func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exp" - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: Alloc: %0 = 
"xla_hlo.exponential" + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) ^bb1: br ^exit(%0 : tensor<2xf32>) @@ -46,18 +46,18 @@ func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ // CHECK-LABEL: Testing : ifElse func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5) - // CHECK: Alloc: %7 = "xla_hlo.exp"(%5) + // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exponential"(%5) + // CHECK: Alloc: %7 = "xla_hlo.exponential"(%5) // CHECK-NEXT: Dealloc: return - %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "xla_hlo.exponential"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> } @@ -65,8 +65,8 @@ func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ // CHECK-LABEL: Testing : ifElseNoUsers func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, 
tensor<2xf32>) @@ -81,8 +81,8 @@ func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ // CHECK-LABEL: Testing : ifElseNested func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) - %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) @@ -93,10 +93,10 @@ func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ ^bb4(%arg8 : tensor<2xf32>): br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>) ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7) - // CHECK: Alloc: %9 = "xla_hlo.exp"(%7) + // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exponential"(%7) + // CHECK: Alloc: %9 = "xla_hlo.exponential"(%7) // CHECK-NEXT: Dealloc: return - %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "xla_hlo.exponential"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 1b7d879ca03..a045e1f9d07 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -90,3 +90,68 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { %0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } + +// CHECK-LABEL: func @fold_pad_into_conv_f32 +func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>, + %arg1 : tensor<7x7x3x64xf32>) + -> tensor<1x16x16x64xf32> { + // CHECK-NOT: xla_hlo.pad + // CHECK: xla_hlo.convolution + // CHECK-SAME: padding = 
dense<3> : tensor<2x2xi64> + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + %1 = "xla_hlo.pad"(%arg0, %0) { + edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>, + edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>, + interior_padding = dense<0> : tensor<4xi64> + } : (tensor<1x32x32x3xf32>, tensor) -> tensor<1x38x38x3xf32> + %2 = "xla_hlo.convolution"(%1, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + padding = dense<0> : tensor<2x2xi64>, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32> + return %2 : tensor<1x16x16x64xf32> +} + +// CHECK-LABEL: func @fold_pad_into_conv_i32 +func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>, + %arg1 : tensor<7x7x3x64xi32>) + -> tensor<1x16x16x64xi32> { + // CHECK-NOT: xla_hlo.pad + // CHECK: xla_hlo.convolution + // CHECK-SAME: padding = dense<3> : tensor<2x2xi64> + %0 = xla_hlo.constant dense<0> : tensor + %1 = "xla_hlo.pad"(%arg0, %0) { + edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>, + edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>, + interior_padding = dense<0> : tensor<4xi64> + } : (tensor<1x32x32x3xi32>, tensor) -> tensor<1x38x38x3xi32> + %2 = "xla_hlo.convolution"(%1, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + 
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32> + return %2 : tensor<1x16x16x64xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 2858c6d9978..5ef76352acc 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -3,10 +3,10 @@ // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exp"(%tensor_operand) + %tensor_result = "xla_hlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -83,9 +83,9 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exp"(%tensor_operand) + %tensor_result = "xla_hlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -222,9 +222,9 @@ func @convert(%operand: memref<2x2xf32>, %result: 
memref<2x2xf32>) { // CHECK-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.cos"(%tensor_operand) + %tensor_result = "xla_hlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: "xla_lhlo.cos"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -234,9 +234,9 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.neg"(%tensor_operand) + %tensor_result = "xla_hlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: "xla_lhlo.neg"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index e1e11d4d37d..67c59ba10c5 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: exp - %0 = "xla_hlo.exp"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: negf - %0 = "xla_hlo.neg"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : 
tensor<2x2xf32> } @@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>, func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: cos - %0 = "xla_hlo.cos"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/inlining.mlir b/tensorflow/compiler/mlir/xla/tests/inlining.mlir index 02df235575c..3e447f7ff11 100644 --- a/tensorflow/compiler/mlir/xla/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/xla/tests/inlining.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: func @caller // CHECK: "xla_hlo.while"{{.*}}( { // CHECK: }, { -// CHECK: "xla_hlo.exp" +// CHECK: "xla_hlo.exponential" // CHECK: }) // CHECK-LABEL: func @callee @@ -23,6 +23,6 @@ func @caller(%arg0: tensor, %pred: tensor) -> tensor { func @callee(%arg0: tensor) -> tensor { - %0 = "xla_hlo.exp"(%arg0) : (tensor) -> tensor + %0 = "xla_hlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir index 10de9bfd2ca..e611b4419c9 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir @@ -47,9 +47,9 @@ func @conditional(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): // CHECK: ^bb2([[VAL4:%.+]]: tensor): - // CHECK: [[VAL5:%.+]] = "xla_hlo.exp"([[VAL4]]) : (tensor) -> tensor + // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor) -> tensor // CHECK: br ^bb3([[VAL5]] : tensor) - %2 = "xla_hlo.exp"(%arg1) : (tensor) -> tensor + %2 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor "xla_hlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor @@ -126,7 +126,7 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, % // CHECK: %3 = "xla_hlo.log"(%2) : (tensor) -> tensor // 
CHECK: br ^[[EXIT:.+]](%3 : tensor) // CHECK: ^[[ELSE_ENTRY]](%4: tensor): - // CHECK: %5 = "xla_hlo.exp"(%4) : (tensor) -> tensor + // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor) -> tensor // CHECK: br ^[[EXIT]](%5 : tensor) // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor @@ -139,7 +139,7 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, % "xla_hlo.return"(%2) : (tensor) -> () }, { ^else_entry(%arg2: tensor): - %2 = "xla_hlo.exp"(%arg2) : (tensor) -> tensor + %2 = "xla_hlo.exponential"(%arg2) : (tensor) -> tensor "xla_hlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor return %1 : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index da736876259..808d0053416 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -30,7 +30,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { func @cond_false(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - %0 = "xla_hlo.exp"(%arg1) : (tensor) -> tensor + %0 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index e271340f247..c1ee15f19de 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -34,8 +34,8 @@ func @dynamic_operand(%arg0: tensor) -> tensor { // CHECK-LABEL: multiple_dialect_ops func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: xla_hlo.neg - %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: xla_hlo.negate + %0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: xla_hlo.abs 
%1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 7c7f6f306cf..805b6711abd 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -629,14 +629,14 @@ func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor< // CHECK-LABEL: func @pow func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.pow + // CHECK-NEXT: xla_hlo.power %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0: tensor<2xf32> } // CHECK-LABEL: func @pow_dynamic func @pow_dynamic(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.pow + // CHECK-NEXT: xla_hlo.power %0 = "tf.Pow"(%arg0, %arg0) : (tensor, tensor) -> tensor return %0: tensor } @@ -668,7 +668,7 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.neg"([[ADD]]) + // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) @@ -690,7 +690,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.neg"([[ADD]]) + // CHECK-DAG: [[NEG:%.+]] = 
"xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) @@ -1015,7 +1015,7 @@ func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { // CHECK-DAG: [[R1:%.*]] = "xla_hlo.real"(%arg0) // CHECK-DAG: [[R2:%.*]] = "xla_hlo.imag"(%arg0) - // CHECK-DAG: [[R3:%.*]] = "xla_hlo.neg"([[R2]]) + // CHECK-DAG: [[R3:%.*]] = "xla_hlo.negate"([[R2]]) // CHECK: [[R4:%.*]] = "xla_hlo.complex"([[R1]], [[R3]]) %1 = "tf.Conj"(%arg0) : (tensor<3xcomplex>) -> tensor<3xcomplex> return %1 : tensor<3xcomplex> @@ -1247,7 +1247,7 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[D:.*]] = "xla_hlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor // CHECK: %[[E:.*]] = "xla_hlo.convert"(%[[B]]) : (tensor) -> tensor - // CHECK: %[[F:.*]] = "xla_hlo.neg"(%[[E]]) : (tensor) -> tensor + // CHECK: %[[F:.*]] = "xla_hlo.negate"(%[[E]]) : (tensor) -> tensor // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> @@ -1651,7 +1651,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) + // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. 
// CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -1721,7 +1721,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) + // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -1881,42 +1881,42 @@ func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { // CHECK-LABEL: @cos func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @cos_dynamic func @cos_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.cos"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.cosine"(%arg0) : (tensor) -> tensor %0 = "tf.Cos"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @cos_unranked func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @exp func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @exp_dynamic func 
@exp_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.exp"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.exponential"(%arg0) : (tensor) -> tensor %0 = "tf.Exp"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @exp_unranked func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -2014,21 +2014,21 @@ func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { // CHECK-LABEL: @neg func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @neg_dynamic func @neg_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.neg"(%arg0) : (tensor) -> tensor + // CHECK: "xla_hlo.negate"(%arg0) : (tensor) -> tensor %0 = "tf.Neg"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @neg_unranked func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -2047,21 +2047,21 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: @sin func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @sin_dynamic func @sin_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.sin"(%arg0) : (tensor) -> tensor + 
// CHECK: "xla_hlo.sine"(%arg0) : (tensor) -> tensor %0 = "tf.Sin"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @sin_unranked func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -2875,7 +2875,7 @@ func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> { - // CHECK: "xla_hlo.conv"(%arg0, %arg1) + // CHECK: "xla_hlo.convolution"(%arg0, %arg1) // Default attributes // CHECK-NOT: lhs_dilation @@ -2906,7 +2906,7 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) - // CHECK-LABEL: depthwiseconv_simple func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> { // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> - // CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]]) + // CHECK: "xla_hlo.convolution"(%arg0, %[[RESHAPED_FILTER]]) // CHECK: feature_group_count = 3 %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { data_format = "NHWC", @@ -2921,7 +2921,7 @@ func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32 // CHECK-LABEL: conv_valid_padding func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { - // CHECK: "xla_hlo.conv"(%arg0, %arg1) + // CHECK: "xla_hlo.convolution"(%arg0, %arg1) %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> return %0 : tensor<1x2x3x1xf32> @@ -2930,7 +2930,7 @@ func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) // CHECK-LABEL: conv_explicit_paddings 
func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> { - // CHECK: "xla_hlo.conv"(%arg0, %arg1) + // CHECK: "xla_hlo.convolution"(%arg0, %arg1) // CHECK-SAME: padding = dense<{{\[\[}}6, 0], [3, 3]]> %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> @@ -2943,7 +2943,7 @@ func @conv2d_backprop_input( %out_backprop: tensor<100x26x26x32xf32> ) -> tensor<100x28x28x1xf32> { // CHECK: %[[REV_FILTER:.*]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: %[[RESULT:.*]] = "xla_hlo.conv"(%arg1, %[[REV_FILTER]]) { + // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg1, %[[REV_FILTER]]) { // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 0 : i64, @@ -2979,7 +2979,7 @@ func @conv2d_backprop_filter( %input: tensor<100x28x28x1xf32>, %out_backprop: tensor<100x26x26x32xf32> ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.conv"(%arg0, %arg1) { + // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg0, %arg1) { // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 3 : i64, @@ -3802,11 +3802,11 @@ func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { // CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> // CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> - // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32> + // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.negate"([[LHSIM]]) : (tensor<5x2xf32>) 
-> tensor<5x2xf32> // CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex> // CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> - // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> + // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.negate"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex> // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex>, {{.*}}) -> tensor<5x2xcomplex> // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex>, {{.*}}) -> tensor<2x4xcomplex> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir index 35546594ccb..fab1389262d 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir @@ -24,8 +24,8 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) { // CHECK-LABEL: func @replace_dependency func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : 
memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () @@ -50,10 +50,10 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>, %arg2: memref<2x2xf32>) { // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> @@ -67,10 +67,10 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> 
() "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> "xla_lhlo.terminator"() : () -> () @@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> "xla_lhlo.terminator"() : () -> () } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 5e4a7fd719f..10e5baa53c8 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exp"(%input, %result) + "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -376,7 +376,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cos"(%input, %result) + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) 
-> () return } @@ -390,7 +390,7 @@ func @cos(%input: memref<2x2xf32>, // CHECK-LABEL: func @neg func @neg(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.neg"(%input, %result) + "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 04d9d23fe8b..7e831eadc2f 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -34,7 +34,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @exp_memref func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.exp"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -50,7 +50,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @neg_memref func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.neg"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 915771923d0..74d175109d3 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -152,7 +152,7 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * 
rhs.imag @@ -191,7 +191,7 @@ func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : te %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag @@ -230,7 +230,7 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag @@ -281,12 +281,12 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exp"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) %3 = "xla_hlo.imag"(%1) : 
(tensor<2xcomplex>) -> (tensor<2xf32>) @@ -298,12 +298,12 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exp"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) + %1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 91eb7493648..07ff6d17091 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -132,8 +132,8 @@ func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.pow %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = 
"xla_hlo.pow"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> + %0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt index 722469eff1a..00f6ec2d308 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt @@ -38,7 +38,7 @@ ENTRY %tfcompile.20 { // CHECK: }, { // CHECK: ^bb0([[A1:%.+]]: tuple>): // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) - // CHECK: [[R8:%.+]] = "xla_hlo.exp"([[R7]]) + // CHECK: [[R8:%.+]] = "xla_hlo.exponential"([[R7]]) // CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]]) // CHECK: "xla_hlo.return"([[R9]]) // CHECK: }) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir index e69d677a8cc..e510a2aa35f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir @@ -21,7 +21,7 @@ func @else_branch(%arg0: tuple>) -> tuple> { %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor // CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]]) - %1 = "xla_hlo.exp"(%0) : (tensor) -> tensor + %1 = "xla_hlo.exponential"(%0) : (tensor) -> tensor // CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]]) %2 = "xla_hlo.tuple"(%1) : (tensor) -> tuple> @@ -50,7 +50,7 @@ func @main(%arg0: tensor) -> tuple> { }, { ^bb0(%arg1: tuple>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - %7 = "xla_hlo.exp"(%6) : (tensor) -> tensor 
+ %7 = "xla_hlo.exponential"(%6) : (tensor) -> tensor %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> "xla_hlo.return"(%8) : (tuple>) -> () }) : (tensor, tuple>, tuple>) -> tuple> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 2436ef32c07..8953516c5fc 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -295,7 +295,7 @@ func @main() -> tensor<2x2x1x1xf32> { // CHECK: HloModule func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - %result = "xla_hlo.conv"(%arg0, %arg1) { + %result = "xla_hlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = { input_batch_dimension = 0 : i64, @@ -985,3 +985,19 @@ func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {tf_device.is_same // CHECK-NOT: parameter_replication={true} // CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true} // CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]]) + +// ----- + +// CHECK: HloModule +func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> (tensor<2xf32>, tensor<2xf64>) { + %0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %1 = "xla_hlo.abs"(%arg1) : (tensor<2xcomplex>) -> (tensor<2xf64>) + return %0, %1 : tensor<2xf32>, tensor<2xf64> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = c64[2] parameter(0) +// CHECK: %[[ABS0:.*]] = f32[2] abs(c64[2] %[[ARG0]]) +// CHECK: %[[ARG1:.*]] = c128[2] parameter(1) +// CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index c8ef751b450..89a34dfa68a 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -224,7 +224,7 @@ add { // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %2 = "xla_hlo.conv"(%1, %cst) { + // CHECK-NEXT: %2 = "xla_hlo.convolution"(%1, %cst) { // CHECK-SAME: batch_group_count = 1 : i64 // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 0 : i64 @@ -260,7 +260,7 @@ add { %test_convolve1D_padding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] { %input = f32[1,2,1] parameter(0) %filter = f32[1,1,1] parameter(1) - // CHECK: "xla_hlo.conv" + // CHECK: "xla_hlo.convolution" // CHECK-SAME: padding = dense<{{\[\[}}1, 2]]> : tensor<1x2xi64> ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } @@ -284,7 +284,7 @@ add { %test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.cos"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "xla_hlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -381,7 +381,7 @@ add { %test_exponential (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.exp"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "xla_hlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %exp.2 = f32[16] 
exponential(f32[16] %arg0.1) } @@ -538,7 +538,7 @@ add { %test_negate (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.neg"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "xla_hlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1) } @@ -609,7 +609,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.pow %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: xla_hlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -848,7 +848,7 @@ add { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.sin"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "xla_hlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -988,3 +988,16 @@ add { // CHECK: xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]] ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2) } + +// CHECK-LABEL: func @complex_type +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xcomplex>, %[[ARG1:.*]]: tensor<2xcomplex>) -> tuple, tensor<2xf64>> +%complex_type (Arg_0.1: c64[2], Arg_1.2: c128[2]) -> (f32[2], f64[2]) { + %Arg_0.1 = c64[2] parameter(0) + %abs.3 = f32[2] abs(c64[2] %Arg_0.1) + %Arg_1.2 = c128[2] parameter(1) + %abs.4 = f64[2] abs(c128[2] %Arg_1.2) + + // CHECK: "xla_hlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> + ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) +} diff --git 
a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc index 3b40f4c8326..540c9ab486d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc @@ -343,7 +343,8 @@ class BufferAssignmentAnalysis { /// the right positions. It uses the algorithm described at the top of the file. // TODO(dfki): create a templated version that allows to match dialect-specific // alloc/dealloc nodes and to insert dialect-specific dealloc node. -struct BufferAssignmentPass : mlir::FunctionPass { +struct BufferAssignmentPass + : mlir::PassWrapper { void runOnFunction() override { // Get required analysis information first. auto& analysis = getAnalysis(); @@ -471,7 +472,7 @@ void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp( // Buffer assignment pass registrations //===----------------------------------------------------------------------===// -std::unique_ptr> createBufferAssignmentPass() { +std::unique_ptr> createBufferAssignmentPass() { return absl::make_unique(); } @@ -482,14 +483,15 @@ static PassRegistration buffer_assignment_pass( /// A simple pass to print debug/test information for the buffer assignment /// analysis. 
-struct BufferAssignmentTestPass : mlir::FunctionPass { +struct BufferAssignmentTestPass + : mlir::PassWrapper { void runOnFunction() override { llvm::outs() << "Testing : " << getFunction().getName() << "\n"; getAnalysis().print(llvm::outs()); }; }; -std::unique_ptr> createBufferAssignmentTestPass() { +std::unique_ptr> createBufferAssignmentTestPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index df9be382f11..65f81aae9f2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -47,3 +47,54 @@ def UnaryEinsumToEinsum : Pat< (HLO_UnaryEinsumOp $operand, $equation), (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), $operand, (UnaryToBinaryEinsumEq $equation))>; + +//===----------------------------------------------------------------------===// +// Conv op patterns. +//===----------------------------------------------------------------------===// + +def IsZero : Attr() &&" + "$_self.cast().isSplat() &&" + "$_self.cast().getSplatValue()" + ".getValue().isZero()) ||" + "($_self.isa() &&" + "$_self.cast().isSplat() &&" + "$_self.cast().getSplatValue()" + ".getInt() == 0)">>; + +def IsOnlyPaddingSpatialDims + : Constraint>; + +def BuildConvPaddingAttrs : NativeCodeCall< + "BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">; + +def FoldPadIntoConv : Pat< + (HLO_ConvOp + (HLO_PadOp $lhs, + (HLO_ConstOp IsZero:$padding_value), + $edge_padding_low, + $edge_padding_high, + IsZero:$interior_padding), + $rhs, + $window_strides, + $padding, + $lhs_dilation, + $rhs_dilation, + $dimension_numbers, + $feature_group_count, + $batch_group_count, + $precision_config), + (HLO_ConvOp + $lhs, + $rhs, + $window_strides, + (BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding, + $dimension_numbers), + $lhs_dilation, + $rhs_dilation, + $dimension_numbers, + 
$feature_group_count, + $batch_group_count, + $precision_config), + [(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low, + $edge_padding_high)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 7215ffef6d3..d3fb832d542 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -324,7 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // "xla_lhlo.terminator"() : () -> () // } -struct HloLegalizeToLhlo : public OperationPass { +struct HloLegalizeToLhlo + : public PassWrapper> { void runOnOperation() override { OwningRewritePatternList patterns; auto& context = getContext(); @@ -473,7 +474,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, // clang-format on } -std::unique_ptr> createLegalizeToLhloPass() { +std::unique_ptr> createLegalizeToLhloPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 3633b32f847..129a24600a2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -37,7 +37,8 @@ using mlir::PassRegistration; namespace mlir { namespace xla_hlo { namespace { -struct LegalizeControlFlow : public mlir::FunctionPass { +struct LegalizeControlFlow + : public mlir::PassWrapper { // Perform the lowering to MLIR control flow. 
void runOnFunction() override; }; @@ -227,7 +228,7 @@ void LegalizeControlFlow::runOnFunction() { } // namespace xla_hlo } // namespace mlir -std::unique_ptr> +std::unique_ptr> mlir::xla_hlo::createLegalizeControlFlowPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index aa6ac85b4af..0ccfdaa3d89 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -55,7 +55,7 @@ namespace mlir { namespace xla_hlo { namespace { -class LegalizeTF : public FunctionPass { +class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} @@ -633,37 +633,24 @@ static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, DenseIntElementsAttr slice_sizes) { auto input_ty = input.getType().dyn_cast(); + if (!input_ty) return false; + auto start_indices_ty = start_indices.getType().dyn_cast(); + if (!start_indices_ty) return false; + int64_t input_rank = input_ty.getRank(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr constant_start_indices; - if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - for (int64_t i = 0; i < input_rank; ++i) { - int64_t slice_size = slice_sizes.getValue(i).getInt(); - int64_t input_size = input_shape[i]; - if (slice_size < 0 || (input_size != -1 && slice_size > input_size)) { - return false; - } - } - return true; - } + bool is_constant_start = + matchPattern(start_indices, m_Constant(&constant_start_indices)); for (int64_t i = 0; i < input_rank; ++i) { int64_t input_size = input_shape[i]; - int64_t start_index = - constant_start_indices.getValue(i).getInt(); int64_t slice_size = slice_sizes.getValue(i).getInt(); - if (start_index < 0) return false; // A slice_size of -1 means "all elements from start_index 
to the end". - // We can't support this semantics for dynamic shapes. - if (slice_size == -1) { - if (input_size == -1) return false; - slice_size = input_size - start_index; - } - if (input_size != -1 && start_index + slice_size > input_size) { - return false; - } + // In order to support these semantics, we need to know both the start index + // and the shape of the input dimension. + if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false; } - return true; } @@ -767,7 +754,7 @@ NamedAttribute GetConvDimensionNumbersAttr( // // Sample result for Conv2D: // -// %conv = "xla_hlo.conv"(%input, %filter) { +// %conv = "xla_hlo.convolution"(%input, %filter) { // strides = [1, 2], // paddings = [[1, 0], [1, 1]], // ... @@ -1525,7 +1512,7 @@ class ConvertSigmoidOp : public OpRewritePattern { // %sub = "xla_hlo.subtract"(%inp, %max) {broadcast_dimensions = 0} // : (tensor, tensor) -> tensor // -// %exp = "xla_hlo.exp"(%sub) : (tensor) -> tensor +// %exp = "xla_hlo.exponential"(%sub) : (tensor) -> tensor // %sum = "tf.Sum"(%exp, %reduce_dim) // : (tensor, tensor<1xi64>) -> tensor // @@ -2221,7 +2208,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // to mark the reduced dimensions. 
SmallVector reduced_dimensions_bitmap(input_shape.size(), false); SmallVector xla_dimensions; - for (APInt index_raw : dimensions.getValues()) { + for (const APInt &index_raw : dimensions.getValues()) { int64_t index = index_raw.getSExtValue(); int64_t rank = input_shape.size(); if ((index < -rank || index >= rank)) return failure(); @@ -2660,7 +2647,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { // Converts hlo.Conv2DBackpropInputOp into: // %rev_filter = "xla_hlo.reverse"(%filter) -// %result = "xla_hlo.conv"(%out_backprop, %rev_filter) +// %result = "xla_hlo.convolution"(%out_backprop, %rev_filter) class ConvertConv2DBackpropInputOp : public OpRewritePattern { public: @@ -2803,7 +2790,7 @@ class ConvertConv2DBackpropInputOp }; // Converts tf.Conv2DBackpropFilterOp into: -// %result = "xla_hlo.conv"(%input, %out_backprop) +// %result = "xla_hlo.convolution"(%input, %out_backprop) class ConvertConv2DBackpropFilterOp : public OpRewritePattern { public: @@ -3829,7 +3816,7 @@ static PassRegistration pass( } // end namespace -std::unique_ptr> createLegalizeTFPass( +std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion) { return std::make_unique(allow_partial_conversion); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 053deddcdfe..86927fe0e07 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -52,13 +52,13 @@ namespace mlir { namespace xla_hlo { namespace { class LegalizeTFControlFlow - : public OperationPass { + : public PassWrapper> { public: void runOnOperation() override; }; } // namespace -std::unique_ptr> +std::unique_ptr> createLegalizeTFControlFlowPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc 
b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 07a74f5cd6a..7ae18eb0d34 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -333,10 +333,14 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { return success(); } -class LegalizeTF : public FunctionPass { +class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; + explicit LegalizeTF(llvm::StringRef device_type) { + device_type_ = device_type.str(); + } + LegalizeTF(const LegalizeTF&) {} void runOnFunction() override { @@ -359,5 +363,10 @@ static PassRegistration pass( } // end namespace +std::unique_ptr> createLegalizeTfWithTf2XlaPass( + llvm::StringRef device_type) { + return std::make_unique(device_type); +} + } // end namespace xla_hlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 0e3c59e06cd..604054bd094 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -177,13 +177,14 @@ class ConvertIotaOp : public OpRewritePattern { } // end anonymous namespace namespace { -struct LegalizeToStandard : public FunctionPass { +struct LegalizeToStandard + : public PassWrapper { /// Perform the lowering to Standard dialect. void runOnFunction() override; }; } // end anonymous namespace -std::unique_ptr> createLegalizeToStdPass() { +std::unique_ptr> createLegalizeToStdPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc index 97341879759..636cdd38ece 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc @@ -30,7 +30,7 @@ namespace { // arguments. 
All uses of each buffer are replaced with the corresponding block // argument and the buffer is freed. Note that this pass only works in regions // with a single block. -struct LhloCopyRemoval : mlir::OperationPass { +struct LhloCopyRemoval : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index fbe8d800306..bdee1b77cff 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -30,12 +30,12 @@ namespace { using linalg::LinalgOp; -class LhloFuseLinalg : public FunctionPass { +class LhloFuseLinalg : public PassWrapper { public: LhloFuseLinalg() = default; LhloFuseLinalg(const LhloFuseLinalg&) {} LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { - tile_sizes_->assign(tile_sizes.begin(), tile_sizes.end()); + tile_sizes_ = tile_sizes; use_parallel_loops_.setValue(use_parallel_loops); } @@ -123,7 +123,7 @@ class LhloFuseLinalg : public FunctionPass { } // namespace -std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr> createLhloFuseLinalg( bool use_parallel_loops, ArrayRef tile_sizes) { return absl::make_unique(use_parallel_loops, tile_sizes); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 15b91edbd8d..164e4dc93d8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -81,7 +81,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, // clang-format on } -struct LhloLegalizeToAffine : public FunctionPass { +struct LhloLegalizeToAffine + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = 
getFunction(); @@ -92,7 +93,7 @@ struct LhloLegalizeToAffine : public FunctionPass { } // namespace -std::unique_ptr> createLegalizeToAffinePass() { +std::unique_ptr> createLegalizeToAffinePass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index 0ea29393744..e6f3ac02d4f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -168,7 +168,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { }; }; -struct LhloLegalizeToGpu : public FunctionPass { +struct LhloLegalizeToGpu : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -186,7 +186,7 @@ struct LhloLegalizeToGpu : public FunctionPass { } // namespace -std::unique_ptr> createLegalizeToGpuPass() { +std::unique_ptr> createLegalizeToGpuPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 8a66a8853fb..489285e02d1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -229,7 +229,7 @@ class ReduceOpConverter : public OpConversionPattern { ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); DenseSet reducing_dims; - for (auto rdim : xla_reduce_op.dimensions().getIntValues()) { + for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) { reducing_dims.insert(rdim.getSExtValue()); } @@ -452,7 +452,7 @@ class ReduceWindowOpConverter }; struct LhloLegalizeToParallelLoops - : public FunctionPass { + : public PassWrapper { void runOnFunction() override { auto func = getFunction(); @@ -478,7 +478,7 @@ 
struct LhloLegalizeToParallelLoops } // namespace -std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { +std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc index 479f865626b..cef3138daf0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc @@ -38,11 +38,12 @@ limitations under the License. using mlir::FunctionPass; using mlir::OwningRewritePatternList; using mlir::PassRegistration; +using mlir::PassWrapper; namespace { -class LowerComplex : public FunctionPass { +class LowerComplex : public PassWrapper { public: - explicit LowerComplex() : FunctionPass() {} + explicit LowerComplex() : PassWrapper() {} /// Performs the lowering to XLA dialect. void runOnFunction() override; diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index 28cbbf9f6e3..026e88b3671 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -39,6 +39,7 @@ using mlir::MLIRContext; using mlir::OpRewritePattern; using mlir::OwningRewritePatternList; using mlir::PassRegistration; +using mlir::PassWrapper; using mlir::PatternRewriter; using mlir::RankedTensorType; using mlir::success; @@ -170,7 +171,8 @@ struct GeneralDotConvert } }; -struct LegalizeGeneralDot : public FunctionPass { +struct LegalizeGeneralDot + : public PassWrapper { /// Lower all general dots that can be represented as a non-batched matmul. 
void runOnFunction() override { OwningRewritePatternList patterns; diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc index 5622892c684..4c02fd5cadc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc @@ -28,7 +28,7 @@ namespace xla_hlo { namespace { struct TestMaterializeBroadcastsPass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index e49ef435de9..2d0164981a3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -28,18 +28,23 @@ class FuncOp; class ModuleOp; class Operation; template -class OpPassBase; +class OperationPass; class Pass; namespace xla_hlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. -std::unique_ptr> createLegalizeTFPass( +std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion = false); +/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the +/// specified device type. +std::unique_ptr> createLegalizeTfWithTf2XlaPass( + llvm::StringRef device_type); + /// Lowers from TF dialect's control flow to HLO dialect's control flow. -std::unique_ptr> createLegalizeTFControlFlowPass(); +std::unique_ptr> createLegalizeTFControlFlowPass(); /// Converts the provided Operation as well as all nested operations into HLO /// dialect using the conversion patterns registered by the HLO dialect. 
When @@ -48,30 +53,30 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false); /// Lowers HLO control flow ops to the Standard dialect. -std::unique_ptr> createLegalizeControlFlowPass(); +std::unique_ptr> createLegalizeControlFlowPass(); /// Lowers from HLO dialect to Standard dialect. -std::unique_ptr> createLegalizeToStdPass(); +std::unique_ptr> createLegalizeToStdPass(); // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary // buffers if necessary. -std::unique_ptr> createLegalizeToLhloPass(); +std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. -std::unique_ptr> createLegalizeHloToLinalgPass(); +std::unique_ptr> createLegalizeHloToLinalgPass(); } // namespace xla_hlo namespace xla_lhlo { // Lowers from LHLO dialect to Affine dialect. -std::unique_ptr> createLegalizeToAffinePass(); +std::unique_ptr> createLegalizeToAffinePass(); // Lowers from LHLO dialect to Linalg dialect. -std::unique_ptr> createLegalizeLhloToLinalgPass(); +std::unique_ptr> createLegalizeLhloToLinalgPass(); // Lowers from LHLO dialect to GPU dialect. -std::unique_ptr> createLegalizeToGpuPass(); +std::unique_ptr> createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. @@ -82,7 +87,7 @@ std::unique_ptr> createLegalizeToGpuPass(); // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as // default. -std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr> createLhloFuseLinalg( bool use_parallel_loops = false, ArrayRef tile_sizes = {}); // Removes unnecessary LHLO copies which copy from the allocated buffers to the @@ -92,7 +97,7 @@ std::unique_ptr> createLhloFuseLinalg( std::unique_ptr createLhloCopyRemovalPass(); // Lowers from LHLO dialect to parallel loops. 
-std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); +std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); } // namespace xla_lhlo @@ -110,7 +115,7 @@ namespace xla { /// 3) Note that the current implementation does not support loops. /// Refer to the class mlir::xla::BufferAssignmentLegalizer for more /// information. -std::unique_ptr> createBufferAssignmentPass(); +std::unique_ptr> createBufferAssignmentPass(); } // namespace xla } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc index 51973191dc7..600c7ece217 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc @@ -27,7 +27,8 @@ namespace xla_hlo { namespace { -struct TestUnfuseBatchNormPass : public OperationPass { +struct TestUnfuseBatchNormPass + : public PassWrapper> { void runOnOperation() override { OwningRewritePatternList patterns; PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index d4f90ade2a2..f9c041f2e28 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -531,7 +531,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } -struct LhloLegalizeToLinalg : public FunctionPass { +struct LhloLegalizeToLinalg + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -545,7 +546,8 @@ struct LhloLegalizeToLinalg : public FunctionPass { } }; -struct HloLegalizeToLinalg : public FunctionPass { +struct HloLegalizeToLinalg + : public 
PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -562,7 +564,7 @@ struct HloLegalizeToLinalg : public FunctionPass { } // namespace namespace xla_lhlo { -std::unique_ptr> createLegalizeLhloToLinalgPass() { +std::unique_ptr> createLegalizeLhloToLinalgPass() { return absl::make_unique(); } @@ -599,7 +601,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter>(context); } -std::unique_ptr> createLegalizeHloToLinalgPass() { +std::unique_ptr> createLegalizeHloToLinalgPass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 8250976eb00..3b1ae934c48 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -45,6 +45,17 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { switch (type.getKind()) { case mlir::StandardTypes::BF16: return PrimitiveType::BF16; + case mlir::StandardTypes::Complex: { + mlir::Type element_ty = type.cast().getElementType(); + switch (element_ty.getKind()) { + case mlir::StandardTypes::F32: + return PrimitiveType::C64; + case mlir::StandardTypes::F64: + return PrimitiveType::C128; + default: + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } + } case mlir::StandardTypes::F16: return PrimitiveType::F16; case mlir::StandardTypes::F32: diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py index 2b3dec3d5a7..92be8e04b71 100644 --- a/tensorflow/compiler/tests/unary_mlir_ops_test.py +++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py @@ -67,9 +67,7 @@ class UnaryOpsTest(xla_test.XLATestCase): equality_test(result, expected, rtol=rtol, atol=atol) def testNumericOps(self): - # TODO(hinsu): Enable complex types after fixing the failure in export to - # HLOModule. 
- for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 6291ea6cbda..af1877a2394 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -90,6 +90,7 @@ cc_library( deps = [ ":trt_allocator", ":trt_conversion", + ":trt_engine_utils", ":trt_logging", ":trt_plugins", ":trt_resources", @@ -215,9 +216,24 @@ cc_library( deps = [ ":get_calibration_data_op_op_lib", ":trt_engine_op_op_lib", + ":trt_engine_utils", ], ) +tf_cuda_library( + name = "trt_engine_utils", + srcs = ["utils/trt_engine_utils.cc"], + hdrs = ["utils/trt_engine_utils.h"], + deps = [ + ":trt_logging", + ":utils", + "@com_google_absl//absl/strings", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:status", + ] + if_tensorrt([":tensorrt_lib"]), +) + tf_cuda_library( name = "trt_logging", srcs = ["utils/trt_logger.cc"], @@ -435,6 +451,8 @@ tf_cuda_cc_test( ":trt_logging", ":trt_conversion", ":trt_plugins", + ":trt_engine_utils", + ":utils", "@com_google_googletest//:gtest", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 400c53614f9..e9e3333ea38 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -30,6 +30,8 @@ limitations under the License. 
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" @@ -1213,26 +1215,12 @@ TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); } -// Input/output data format for OpConverterTest::BuildAndRun(). -struct InputOutputData { - void* Buffer() const { - return const_cast(tensor.tensor_data().data()); - } - - size_t TotalBytes() const { return tensor.TotalBytes(); } - - string name; - Tensor tensor; -}; - template Tensor ConstructTensor(int data_size, const T& value = T()) { std::vector values(data_size, value); return test::AsTensor(values); } -using DataVec = std::vector; - template inline absl::Span GetSpanForData(const InputOutputData& data) { const auto& tensor_map = data.tensor.flat(); @@ -1308,10 +1296,31 @@ class OpConverterTest : public ::testing::Test { CheckDataTypeMatches(input_data); CheckDataTypeMatches(*output_data); - // Execute the TRT engine. const int num_bindings = input_data.size() + output_data->size(); std::vector buffers(num_bindings); + ASSERT_EQ(engine_->getNbBindings(), num_bindings); + TrtUniquePtrType execution_context( + engine_->createExecutionContext()); + + // Prepare input bindings. + TF_ASSERT_OK(SetTrtEngineInputs(engine_.get(), execution_context.get(), 0, + buffers, converter_->use_implicit_batch(), + batch_size, nullptr, &input_data)); + + // Prepare output bindings. + TF_ASSERT_OK(SetTrtEngineOutputs(engine_.get(), execution_context.get(), 0, + buffers, converter_->use_implicit_batch(), + batch_size, nullptr, output_data)); + + // Allocate buffers on GPU and copy data there. 
This is necessary because + // the test tensors are allocated in host memory, so the pointers that + // SetTrtEngin(In|Out)puts placed into buffers[] cannot be used on the GPU. + // We allocate the GPU buffers, copy the data there, and overwrite the + // addresses in the buffers array. + // + // TODO(tfeher): This step can be avoided if we allocate the Tensors in + // unified memory. for (const auto& data : input_data) { const int input_index = engine_->getBindingIndex(data.name.c_str()); ASSERT_NE(-1, input_index); @@ -1334,10 +1343,9 @@ class OpConverterTest : public ::testing::Test { ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); } - ASSERT_EQ(engine_->getNbBindings(), num_bindings); - TrtUniquePtrType execution_context( - engine_->createExecutionContext()); - execution_context->enqueue(batch_size, buffers.data(), stream_, nullptr); + // Execute the TRT engine. + TF_ASSERT_OK(TrtEnqueue(execution_context.get(), buffers, stream_, + converter_->use_implicit_batch(), batch_size)); for (int i = 0; i < output_infos.size(); ++i) { const auto& output_info = output_infos[i]; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index c83b84998fa..fb3ae6943d3 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -163,6 +163,28 @@ bool AreShapesCompatible(const std::vector& actual_shapes, return true; } +Status TrtDimsToTensorShape(const std::vector& trt_dims, + bool use_implicit_batch, int batch_size, + TensorShape& shape) { + TF_RETURN_IF_ERROR( + TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), &shape)); + if (use_implicit_batch) { + shape.InsertDim(0, batch_size); + } + return Status::OK(); +} + +Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, + bool use_implicit_batch, int batch_size, + TensorShape& shape) { + TF_RETURN_IF_ERROR( + TensorShapeUtils::MakeShape(trt_dims.d, trt_dims.nbDims, &shape)); + 
if (use_implicit_batch) { + shape.InsertDim(0, batch_size); + } + return Status::OK(); +} + int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { int n_bindings = engine->getNbBindings(); int n_input = 0; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 139984616f0..5d4cf1bb851 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -98,6 +98,14 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TrtDimsToTensorShape(const std::vector& trt_dims, + bool use_implicit_batch, int batch_size, + TensorShape& shape); + +Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, + bool use_implicit_batch, int batch_size, + TensorShape& shape); + // Returns a string that includes compile time TensorRT library version // information {Maj, Min, Patch}. string GetLinkedTensorRTVersion(); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index a0524f4a90e..66a1a96d96d 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" @@ -109,8 +110,8 @@ class TRTEngineOp : public AsyncOpKernel { // Executes the tensorrt engine. Returns whether we need to retry by running // the native segment. 
- bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context, - int trt_context_idx); + Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context, + int trt_context_idx); // Allocates necessary resources for calibration. Status AllocateCalibrationResources(OpKernelContext* ctx, @@ -602,11 +603,10 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, ExecuteNativeSegment(ctx, helper); return; } - - const bool retry = ExecuteTrtEngine(ctx, engine_context, trt_context_idx); - if (retry) { - LOG(WARNING) << "Failed to execute engine, " - << "retrying with native segment for " << name(); + Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx); + if (!stat.ok()) { + LOG(WARNING) << "Failed to execute engine: " << stat + << " Retrying with native segment for " << name(); // Release any outputs that are allocated, ExecuteNativeSegment will // re-allocate them and fail if they are currently allocated. for (int i = 0; i < ctx->num_outputs(); i++) { @@ -617,42 +617,9 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } } -// Gets the binding index of a tensor in an engine. -// -// The binding index is looked up using the tensor's name and the profile index. -// Profile index should be set to zero, if we do not have optimization profiles. -Status GetTrtBindingIndex(const char* tensor_name, int profile_index, - const nvinfer1::ICudaEngine* cuda_engine, - int* binding_index) { - // If the engine has been built for K profiles, the first getNbBindings() / K - // bindings are used by profile number 0, the following getNbBindings() / K - // bindings are used by profile number 1 etc. - // - // GetBindingIndex(tensor_name) returns the binding index for the progile 0. - // We can also consider it as a "binding_index_within_profile". 
- *binding_index = cuda_engine->getBindingIndex(tensor_name); - if (*binding_index == -1) { - const string msg = StrCat("Input node ", tensor_name, " not found"); - LOG(ERROR) << msg; - return errors::NotFound(msg); - } -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - int n_profiles = cuda_engine->getNbOptimizationProfiles(); -#else - int n_profiles = 1; -#endif - // If we have more then one optimization profile, then we need to shift the - // binding index according to the following formula: - // binding_index_within_engine = binding_index_within_profile + - // profile_index * bindings_per_profile - const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles; - *binding_index = *binding_index + profile_index * bindings_per_profile; - return Status::OK(); -} - -bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, - EngineContext* engine_context, - int trt_context_idx) { +Status TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, + EngineContext* engine_context, + int trt_context_idx) { VLOG(1) << "Executing TRT engine: " << name(); auto& cuda_engine = engine_context->cuda_engine; @@ -677,163 +644,24 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, const int num_binding = cuda_engine->getNbBindings(); std::vector buffers(num_binding); + // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex + // for it. mutex_lock lock(engine_context->mu); nvinfer1::IExecutionContext* execution_context; - Status status = - engine_context->GetExecutionContext(trt_context_idx, &execution_context); - const bool kRetry = true; - if (!status.ok()) { - // TODO(Tamas) let ExecuteTrtEngine return a status, and do the logging at - // the call site - LOG(ERROR) << status; - return kRetry; - } + TF_RETURN_IF_ERROR( + engine_context->GetExecutionContext(trt_context_idx, &execution_context)); - // Setup engine inputs. 
- for (int i = 0; i < ctx->num_inputs(); i++) { - const string input_name = StrCat(IONamePrefixes::kInputPHName, i); - int binding_index; - auto status = GetTrtBindingIndex(input_name.c_str(), trt_context_idx, - cuda_engine.get(), &binding_index); - if (!status.ok()) { - ctx->SetStatus(status); - return !kRetry; - } + const int num_batch = + use_implicit_batch_ ? ctx->input(0).shape().dim_size(0) : 0; - const Tensor& input_tensor = ctx->input(i); - const TensorShape& input_shape = input_tensor.shape(); + TF_RETURN_IF_ERROR(SetTrtEngineInputs(cuda_engine.get(), execution_context, + trt_context_idx, buffers, + use_implicit_batch_, num_batch, ctx)); - if (use_implicit_batch_) { - // Ensure all inputs have the same batch size - const int num_batch = ctx->input(0).shape().dim_size(0); - if (num_batch != input_shape.dim_size(0)) { - LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch - << " vs " << input_shape.dim_size(0); - return kRetry; - } - } -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - // Set known input dimensions. This is necessary because TRT network - // could be made with dynamic dimensions. - if (!use_implicit_batch_) { - nvinfer1::Dims trt_dims; - trt_dims.nbDims = input_shape.dims(); - for (int k = 0; k < input_shape.dims(); k++) { - trt_dims.d[k] = input_shape.dim_size(k); - } - execution_context->setBindingDimensions(binding_index, trt_dims); - } -#endif - // Setup input bindings. 
- auto dtype = cuda_engine->getBindingDataType(binding_index); - switch (dtype) { - case nvinfer1::DataType::kFLOAT: - buffers[binding_index] = - const_cast(input_tensor.flat().data()); - break; - case nvinfer1::DataType::kHALF: - buffers[binding_index] = - const_cast(input_tensor.flat().data()); - break; - case nvinfer1::DataType::kINT8: - LOG(ERROR) << "INT8 inputs are not supported yet!"; - return kRetry; - case nvinfer1::DataType::kINT32: - buffers[binding_index] = - const_cast(input_tensor.flat().data()); - break; - default: - LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); - return kRetry; - } - } + TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine.get(), execution_context, + trt_context_idx, buffers, + use_implicit_batch_, num_batch, ctx)); -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - // Ensure all network dynamic dimensions (if any) are set in execution - // context. - if (!execution_context->allInputDimensionsSpecified()) { - LOG(WARNING) << "Failed to set dimensions for all dynamic input tensors."; - return kRetry; - } - if (!execution_context->allInputShapesSpecified()) { - LOG(WARNING) << "Failed to set dimensions for all shape input tensors."; - return kRetry; - } -#endif - - // Setup engine outputs. - for (int i = 0; i < ctx->num_outputs(); i++) { - const string output_name = StrCat(IONamePrefixes::kOutputPHName, i); - int binding_index; - auto status = GetTrtBindingIndex(output_name.c_str(), trt_context_idx, - cuda_engine.get(), &binding_index); - if (!status.ok()) { - ctx->SetStatus(status); - return !kRetry; - } - // Get TRT output shapes for allocating output memory. - std::vector trt_shape; - if (!use_implicit_batch_) { - // Explicit batch mode just copy output dims to trt_shape -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - // Get dims from context instead of engine in explicit batch mode - // because engine might have dynamic shapes. 
- auto dims = execution_context->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - trt_shape.push_back(dims.d[j]); - } -#else - LOG(ERROR) - << "Explicit batch mode is only supported with TensorRT 6 and above."; - return kRetry; -#endif - } else { - // Implicit batch mode, it's assumed that first dimension of all inputs - // and outputs is batch size. We prepend the batch dim to trt_shape. - auto dims = cuda_engine->getBindingDimensions(binding_index); - trt_shape.push_back(ctx->input(0).shape().dim_size(0)); - for (int j = 0; j < dims.nbDims; j++) { - trt_shape.push_back(dims.d[j]); - } - } - // Allocate output tensor of TRTEngineOp. - Tensor* output_tensor = nullptr; - TensorShape output_shape; - status = TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), - &output_shape); - if (!status.ok()) { - LOG(ERROR) << "Failed to get output shape: " << status; - return kRetry; - } - status = ctx->allocate_output(i, output_shape, &output_tensor); - if (!status.ok()) { - LOG(ERROR) << "Allocating output failed with " << status; - ctx->SetStatus(status); - return kRetry; - } - // Setup output bindings. 
- auto dtype = cuda_engine->getBindingDataType(binding_index); - switch (dtype) { - case nvinfer1::DataType::kFLOAT: - buffers[binding_index] = - const_cast(output_tensor->flat().data()); - break; - case nvinfer1::DataType::kHALF: - buffers[binding_index] = - const_cast(output_tensor->flat().data()); - break; - case nvinfer1::DataType::kINT8: - LOG(WARNING) << "int8 is not supported yet!"; - return kRetry; - case nvinfer1::DataType::kINT32: - buffers[binding_index] = - const_cast(output_tensor->flat().data()); - break; - default: - LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); - return kRetry; - } - } // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( reinterpret_cast(ctx->op_device_context() @@ -841,29 +669,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, ->implementation() ->GpuStreamMemberHack())); - // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex - // for it. - bool ret = false; - if (use_implicit_batch_) { - const int num_batch = ctx->input(0).shape().dim_size(0); - ret = execution_context->enqueue(num_batch, &buffers[0], *stream, nullptr); - VLOG(1) << "Called IExecutionContext::enqueue"; - } else { -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - ret = execution_context->enqueueV2(&buffers[0], *stream, nullptr); - VLOG(1) << "Called IExecutionContext::enqueueV2"; -#else - LOG(ERROR) - << "Explicit batch mode is only supported with TensorRT 6 and above."; - return kRetry; -#endif - } - if (!ret) { - LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); - return kRetry; - } - // Synchronization will be done by TF. 
- return !kRetry; + TF_RETURN_IF_ERROR(TrtEnqueue(execution_context, buffers, *stream, + use_implicit_batch_, num_batch)); + return Status::OK(); } Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx, diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc index 0d7819931b1..03f77c6bd5f 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" std::tuple get_linked_tensorrt_version() { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc new file mode 100644 index 00000000000..213c1732e59 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc @@ -0,0 +1,253 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using absl::StrCat; + +Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, + const nvinfer1::IExecutionContext* execution_context, + int binding_index, bool use_implicit_batch, + int batch_size, TensorShape& shape) { + nvinfer1::Dims dims; + if (use_implicit_batch) { + dims = cuda_engine->getBindingDimensions(binding_index); + } else { +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Get dims from context instead of engine in explicit batch mode because + // the engine might have dynamic shapes. + dims = execution_context->getBindingDimensions(binding_index); +#else + return errors::Internal( + "Explicit batch mode is only supported with TensorRT 6 and above."); +#endif + } + TF_RETURN_IF_ERROR( + TrtDimsToTensorShape(dims, use_implicit_batch, batch_size, shape)); + return Status::OK(); +} + +Status GetTrtBindingIndex(const char* tensor_name, int profile_index, + const nvinfer1::ICudaEngine* cuda_engine, + int* binding_index) { + // If the engine has been built for K profiles, the first getNbBindings() / K + // bindings are used by profile number 0, the following getNbBindings() / K + // bindings are used by profile number 1 etc. + // + // GetBindingIndex(tensor_name) returns the binding index for the profile 0. + // We can also consider it as a "binding_index_within_profile".
+ *binding_index = cuda_engine->getBindingIndex(tensor_name); + if (*binding_index == -1) { + const string msg = StrCat("Input node ", tensor_name, " not found"); + LOG(ERROR) << msg; + return errors::NotFound(msg); + } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + int n_profiles = cuda_engine->getNbOptimizationProfiles(); +#else + int n_profiles = 1; +#endif + // If we have more than one optimization profile, then we need to shift the + // binding index according to the following formula: + // binding_index_within_engine = binding_index_within_profile + + // profile_index * bindings_per_profile + const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles; + *binding_index = *binding_index + profile_index * bindings_per_profile; + return Status::OK(); +} + +Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + const int trt_profile_idx, + std::vector& buffers, bool use_implicit_batch, + int num_batch, OpKernelContext* ctx, + const DataVec* input_vec) { + int n_inputs = ctx ? ctx->num_inputs() : (input_vec ? input_vec->size() : 0); + // Setup engine inputs. + for (int i = 0; i < n_inputs; i++) { + const string input_name = + ctx ? StrCat(IONamePrefixes::kInputPHName, i) : input_vec->at(i).name; + int binding_index; + TF_RETURN_IF_ERROR(GetTrtBindingIndex(input_name.c_str(), trt_profile_idx, + cuda_engine, &binding_index)); + const Tensor& input_tensor = ctx ? ctx->input(i) : input_vec->at(i).tensor; + const TensorShape& input_shape = input_tensor.shape(); + + if (use_implicit_batch && ctx) { + // Ensure all inputs have the same batch size + if (num_batch != input_shape.dim_size(0)) { + const string msg = + StrCat("Input data has inconsistent batch size: ", num_batch, + " vs ", input_shape.dim_size(0)); + return errors::NotFound(msg); + } + } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Set known input dimensions. This is necessary because TRT network + // could be made with dynamic dimensions.
+ if (!use_implicit_batch) { + nvinfer1::Dims trt_dims; + trt_dims.nbDims = input_shape.dims(); + for (int k = 0; k < input_shape.dims(); k++) { + trt_dims.d[k] = input_shape.dim_size(k); + } + execution_context->setBindingDimensions(binding_index, trt_dims); + } +#endif + // Setup input bindings. + auto dtype = cuda_engine->getBindingDataType(binding_index); + switch (dtype) { + case nvinfer1::DataType::kFLOAT: + buffers[binding_index] = + const_cast(input_tensor.flat().data()); + break; + case nvinfer1::DataType::kHALF: + buffers[binding_index] = + const_cast(input_tensor.flat().data()); + break; + case nvinfer1::DataType::kINT8: + return errors::Internal("INT8 inputs are not supported yet!"); + case nvinfer1::DataType::kINT32: + buffers[binding_index] = + const_cast(input_tensor.flat().data()); + break; + default: + return errors::Internal("Unknown TRT data type: ", + static_cast(dtype)); + } + } + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // Ensure all network dynamic dimensions (if any) are set in execution + // context. + if (!execution_context->allInputDimensionsSpecified()) { + return errors::Internal( + "Failed to set dimensions for all dynamic input tensors"); + } + if (!execution_context->allInputShapesSpecified()) { + return errors::Internal( + "Failed to set dimensions for all shape input tensors."); + } +#endif + return Status::OK(); +} + +Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + int trt_profile_idx, std::vector& buffers, + bool use_implicit_batch, int batch_size, + OpKernelContext* ctx, DataVec* outputs) { + // Either one of ctx or outputs should be specified + int n_outputs = ctx ? ctx->num_outputs() : (outputs ? outputs->size() : 0); + for (int i = 0; i < n_outputs; i++) { + const string output_name = + ctx ?
StrCat(IONamePrefixes::kOutputPHName, i) : outputs->at(i).name; + int binding_index; + TF_RETURN_IF_ERROR(GetTrtBindingIndex(output_name.c_str(), trt_profile_idx, + cuda_engine, &binding_index)); + + // Get TRT output shapes for allocating output memory. + TensorShape output_shape; + TF_RETURN_IF_ERROR(GetTrtBindingShape(cuda_engine, execution_context, + binding_index, use_implicit_batch, + batch_size, output_shape)); + + // Allocate output tensor of TRTEngineOp. + Tensor* output_tensor = nullptr; + if (ctx) { + TF_RETURN_IF_ERROR(ctx->allocate_output(i, output_shape, &output_tensor)); + } else { + // This path is used for unit tests. The tensor is already allocated. + // Its shape is not necessarily set correctly, we fix that. + VLOG(2) << "Applying shape " << output_shape.DebugString() + << " on output."; + output_tensor = &(outputs->at(i).tensor); + bool status = output_tensor->CopyFrom(*output_tensor, output_shape); + if (!status) { + return errors::Internal( + "Buffer size do not match while reshaping output tensors"); + } + } + + // Setup output bindings. 
+ auto dtype = cuda_engine->getBindingDataType(binding_index); + switch (dtype) { + case nvinfer1::DataType::kFLOAT: + buffers[binding_index] = + const_cast(output_tensor->flat().data()); + break; + case nvinfer1::DataType::kHALF: + buffers[binding_index] = + const_cast(output_tensor->flat().data()); + break; + case nvinfer1::DataType::kINT8: + return errors::Internal("int8 is not supported yet!"); + case nvinfer1::DataType::kINT32: + buffers[binding_index] = + const_cast(output_tensor->flat().data()); + break; + default: + return errors::Internal("Unknown TRT data type: ", + static_cast(dtype)); + } + } + return Status::OK(); +} + +Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, + std::vector& buffers, cudaStream_t stream, + bool use_implicit_batch, int batch_size) { + bool ret = false; + if (use_implicit_batch) { + ret = execution_context->enqueue(batch_size, &buffers[0], stream, nullptr); + VLOG(1) << "Called IExecutionContext::enqueue"; + } else { +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + ret = execution_context->enqueueV2(&buffers[0], stream, nullptr); + VLOG(1) << "Called IExecutionContext::enqueueV2"; +#else + return errors::Internal( + "Explicit batch mode is only supported with TensorRT 6 and above."); +#endif + } + if (!ret) { + return errors::Internal("Failed to enqueue batch for TRT engine"); + } + // Synchronization will be done by TF. + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h new file mode 100644 index 00000000000..a471749877a --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// Input/output data format for OpConverterTest::BuildAndRun(). +struct InputOutputData { + void* Buffer() const { + return const_cast(tensor.tensor_data().data()); + } + + size_t TotalBytes() const { return tensor.TotalBytes(); } + + string name; + Tensor tensor; +}; + +using DataVec = std::vector; + +// Gets the binding index of a tensor in an engine. +// +// The binding index is looked up using the tensor's name and the profile index. +// Profile index should be set to zero, if we do not have optimization profiles. +Status GetTrtBindingIndex(const char* tensor_name, int profile_index, + const nvinfer1::ICudaEngine* cuda_engine, + int* binding_index); + +// Sets input buffers for TRT from a list of input tensors. The input tensors +// are either defined by ctx or by input_vec. 
+Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + const int trt_profile_idx, + std::vector& buffers, bool use_implicit_batch, + int num_batch, OpKernelContext* ctx = nullptr, + const DataVec* input_vec = nullptr); + +// Returns the shape of a binding from TensorRT. +// +// The binding is identified by its binding_index. The batch_size argument is +// ignored if use_implicit_batch==false. The shape is returned in the last +// argument. +Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, + const nvinfer1::IExecutionContext* execution_context, + int binding_index, bool use_implicit_batch, + int batch_size, TensorShape& shape); + +// Defines output buffers for TRT. The buffers are allocated by ctx, if ctx is +// not null. Otherwise it is expected that the outputs DataVec is not null, and +// the Tensors in outputs are already allocated. +Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + int trt_profile_idx, std::vector& buffers, + bool use_implicit_batch, int batch_size = 0, + OpKernelContext* ctx = nullptr, + DataVec* outputs = nullptr); + +// Enqueues TensorRT inference job. The batch_size argument is only relevant in +// implicit batch mode. 
+Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, + std::vector& buffers, cudaStream_t stream, + bool use_implicit_batch, int batch_size = 1); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1c5867a1312..a5332385994 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -167,6 +167,7 @@ cc_library( ":tf2xla_proto_cc", ":tf2xla_util", ":xla_compiler", + "//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 43404bc2267..daf261fa5d8 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -174,7 +174,8 @@ Status ConvertGraphDefToXlaViaMlir( // Convert the MLIR module to XLA computation. If the input graph can't be // lowered down to a single graph node with a single island by the previous // step, this step will return an error. 
- return ConvertMLIRToXlaComputation(*module, computation, + return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT", + computation, /*use_tuple_args=*/false, /*always_return_tuple=*/true); } diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 6fcdef46f29..38ddbd5abf7 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -41,6 +41,7 @@ xla_test( deps = [ ":arithmetic", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index de573429fdc..a24f110fd7a 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -126,7 +126,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, XlaOp rhs_index = Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index"); - auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value); + auto cmp = is_min ? 
Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value); XlaOp max = Select(cmp, lhs_value, rhs_value); XlaOp arg_max = Select(cmp, lhs_index, rhs_index); Tuple(b, {max, arg_max}); @@ -178,36 +178,22 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); } + XlaOp iota = Iota( + builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis); XlaOp input_max = Reduce(input, init_value, reducer, /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.rank() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - XlaOp partial_mask = - ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); - XlaOp full_mask = ShiftRightArithmetic( - ShiftLeft(partial_mask, shift_amount), shift_amount); + XlaOp max_idx = MaxValue(builder, output_type); + XlaOp select_mask = Select(Eq(input, input_max, broadcast_dims), + /*on_true=*/iota, + /*on_false=*/ + max_idx); - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); - XlaOp iota = Iota(builder, output_type, axis_size); - XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. 
- return Reduce(product, MinValue(builder, output_type), - CreateScalarMaxComputation(output_type, builder), + return Reduce(select_mask, max_idx, + CreateScalarMinComputation(output_type, builder), /*dimensions_to_reduce=*/{axis}); }); } diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc index a13839f9db8..d3ff14d8a9b 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -14,8 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/arithmetic.h" + +#include + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -25,42 +29,65 @@ limitations under the License. namespace xla { namespace { -using ArithmeticTest = ClientLibraryTestBase; +class ArithmeticTest : public ClientLibraryTestBase { + public: + template + void TestArgMin(std::initializer_list> input, + absl::Span expected_output, int axis) { + return TestArgMinMax(input, expected_output, axis, /*is_min=*/true); + } + + template + void TestArgMax(std::initializer_list> input, + absl::Span expected_output, int axis) { + return TestArgMinMax(input, expected_output, axis, /*is_min=*/false); + } + + private: + // Test ArgMin/ArgMax implementation, both single- and two- pass. 
+ template + void TestArgMinMax( + std::initializer_list> input, + absl::Span expected_output, int axis, bool is_min) { + if (is_min) { + TestArgMinMaxImpl(input, expected_output, axis, &ArgMin); + TestArgMinMaxImpl(input, expected_output, axis, &ArgMinTwoPass); + } else { + TestArgMinMaxImpl(input, expected_output, axis, &ArgMax); + TestArgMinMaxImpl(input, expected_output, axis, &ArgMaxTwoPass); + } + } + + template + void TestArgMinMaxImpl( + std::initializer_list> input, + absl::Span expected_output, int axis, + std::function MinMaxImpl) { + XlaBuilder builder(TestName()); + XlaOp x = ConstantR2(&builder, input); + MinMaxImpl(x, primitive_util::NativeToPrimitiveType(), axis); + ComputeAndCompareR1(&builder, expected_output, {}); + } +}; XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { - XlaBuilder builder(TestName()); - auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMin(x, S32, /*axis=*/0); - - std::vector expected = {0, 2, 2}; - ComputeAndCompareR1(&builder, expected, {}); + TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2}, + /*axis=*/0); } XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { - XlaBuilder builder(TestName()); - auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMin(x, S32, /*axis=*/1); - - std::vector expected = {0, 1, 2}; - ComputeAndCompareR1(&builder, expected, {}); + TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1}, + /*axis=*/1); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { - XlaBuilder builder(TestName()); - auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMax(x, S32, /*axis=*/0); - - std::vector expected = {2, 0, 1}; - ComputeAndCompareR1(&builder, expected, {}); + TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1}, + /*axis=*/0); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { - XlaBuilder builder(TestName()); - auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMax(x, S32, /*axis=*/1); - - std::vector expected = {1, 0, 0}; - 
ComputeAndCompareR1(&builder, expected, {}); + TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0}, + /*axis=*/1); } } // namespace diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 888db7536e4..3dd2b935109 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1568,6 +1568,51 @@ XlaOp XlaBuilder::CustomCall( }); } +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const XlaComputation& computation, const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (absl::StartsWith(call_target_name, "$")) { + return InvalidArgument( + "Invalid custom_call_target \"%s\": Call targets that start with '$' " + "are reserved for internal use.", + call_target_name); + } + *instr.mutable_shape() = shape.ToProto(); + instr.set_custom_call_target(call_target_name); + instr.set_backend_config(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + "constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); + ++operand_num; + } + } + AddCalledComputation(computation, &instr); + 
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); + }); +} + XlaOp XlaBuilder::Transpose(XlaOp operand, absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -3173,6 +3218,16 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, /*operand_shapes_with_layout=*/absl::nullopt); } +XlaOp CustomCallWithComputation(XlaBuilder* builder, + const string& call_target_name, + absl::Span operands, + const XlaComputation& computation, + const Shape& shape, const string& opaque) { + return builder->CustomCall(call_target_name, operands, computation, shape, + opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, absl::Span operand_shapes_with_layout, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 75975baba91..15411edb051 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -493,6 +493,12 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); + XlaOp CustomCall( + const string& call_target_name, absl::Span operands, + const XlaComputation& computation, const Shape& shape_with_layout, + const string& opaque, + absl::optional> operand_shapes_with_layout); + XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); @@ -885,6 +891,12 @@ class XlaBuilder { friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque); + friend XlaOp CustomCallWithComputation(XlaBuilder* builder, + const string& call_target_name, + absl::Span operands, + const XlaComputation& computation, + const Shape& shape, + const string& opaque); friend XlaOp CustomCallWithLayout( XlaBuilder* builder, 
const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, @@ -1580,6 +1592,13 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque = ""); +// Overload which constructs a custom call that applies an Xla computation. +XlaOp CustomCallWithComputation(XlaBuilder* builder, + const string& call_target_name, + absl::Span operands, + const XlaComputation& computation, + const Shape& shape, const string& opaque = ""); + // Overload which constructs a custom call with fixed layouts. The operands will // have the layouts specified by |operand_shapes_with_layout| when provided to // external code, and the external code is expected to produce a result with the diff --git a/tensorflow/compiler/xla/python/bfloat16.h b/tensorflow/compiler/xla/python/bfloat16.h index 02fd10e04e8..9e52d086919 100644 --- a/tensorflow/compiler/xla/python/bfloat16.h +++ b/tensorflow/compiler/xla/python/bfloat16.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 09700841ab4..88548ba5b2a 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/python/local_client.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index dc6e8c5b500..4d7a6335c3f 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -1413,6 +1413,24 @@ PyLocalExecutable::Execute(absl::Span argument_handles, RunId(), options); } +StatusOr>> +PyLocalExecutable::ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, + const ExecuteOptions& options) const { + for (int i = 0; i < local_devices_.size(); ++i) { + if (local_devices_[i] == device) { + VLOG(1) << "Executing computation " << name(); + return ExecuteHelper(argument_handles, + /*replica=*/local_logical_device_ids_[i].first, + /*partition=*/local_logical_device_ids_[i].second, + RunId(), options); + } + } + return InvalidArgument( + "Attempted to execute on device id %d which is not a local device", + device->id()); +} + StatusOr>>> PyLocalExecutable::ExecuteOnLocalDevices( absl::Span> argument_handles, @@ -1435,8 +1453,8 @@ PyLocalExecutable::ExecuteOnLocalDevices( VLOG(1) << "Executing computation " << name() << "; num_replicas=" << num_replicas() - << " num_partitions=" << num_partitions() << " num_local_devices=8" - << num_local_devices; + << " num_partitions=" << num_partitions() + << " num_local_devices=" << num_local_devices; std::vector>>> results( num_local_devices); if (num_local_devices == 1) { diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 63786042955..cd7d1be0b0a 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -329,6 +329,9 @@ class PyLocalBuffer { std::shared_ptr buffer_reference, 
PyLocalClient* client, Device* device); + // Note that literal must remain in scope until the transfer has completed, so + // the caller should, for example, wait for BlockHostUntilReady() completes on + // the return value before letting literal go out of scope. static StatusOr> FromHostLiteral( const LiteralSlice& literal, PyLocalClient* client, Device* device); @@ -558,6 +561,10 @@ class PyLocalExecutable { absl::Span argument_handles, const ExecuteOptions& options) const; + StatusOr>> ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, + const ExecuteOptions& options) const; + // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). The number of argument lists must be equal to the local device diff --git a/tensorflow/compiler/xla/python/python_ref_manager.h b/tensorflow/compiler/xla/python/python_ref_manager.h index 2c6ea16c7f7..0ad533c695f 100644 --- a/tensorflow/compiler/xla/python/python_ref_manager.h +++ b/tensorflow/compiler/xla/python/python_ref_manager.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 88d17cb8e2a..83a3e5b3db9 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h" #include "tensorflow/compiler/xla/python/types.h" diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index d555661d49d..7a29f9dca28 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -22,9 +22,9 @@ limitations under the License. #include "numpy/arrayobject.h" #include "absl/container/inlined_vector.h" #include "absl/types/optional.h" -#include "include/pybind11/numpy.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/shape.h" diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 5510c4be056..1cdff854f21 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -24,10 +24,10 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "include/pybind11/cast.h" -#include "include/pybind11/numpy.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/math.h" diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c8da3d3ccbe..f6055393493 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -706,6 +706,146 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { add->shape(), HloOpcode::kAdd, a, b)), c)); } + + if (options_.is_layout_sensitive()) { + return Status::OK(); + } + + HloInstruction* lhs_scatter_operand = nullptr; + HloInstruction* rhs_scatter_operand = nullptr; + HloInstruction* lhs_scatter_update = nullptr; + HloInstruction* rhs_scatter_update = nullptr; + HloInstruction* lhs_scatter_index = nullptr; + HloInstruction* rhs_scatter_index = nullptr; + bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand), + m::Op(&lhs_scatter_index), + m::Op(&lhs_scatter_update)) + .WithOneUse()) && + Match(lhs->to_apply()->root_instruction(), + m::Add(m::Parameter(), m::Parameter())); + bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand), + m::Op(&rhs_scatter_index), + m::Op(&rhs_scatter_update)) + .WithOneUse()) && + Match(rhs->to_apply()->root_instruction(), + m::Add(m::Parameter(), m::Parameter())); + if (rhs_scatter && lhs_scatter) { + const auto& lhs_dnums = lhs->scatter_dimension_numbers(); + const auto& rhs_dnums = rhs->scatter_dimension_numbers(); + absl::optional 
index_concat_dimension; + absl::optional update_concat_dimension; + // Don't try to combine scatters of different ranks. + if (lhs_scatter_index->shape().rank() != + rhs_scatter_index->shape().rank()) { + return Status::OK(); + } + + int64 first_index_dim = lhs_scatter_index->shape().rank(); + int64 first_update_dim = lhs_scatter_update->shape().rank(); + // Find a dimension where it is possible to concatenate the indices and + // updates. This is the first and only non-equal dimension or the first + // equally sized dimension. + for (int64 d = lhs_scatter_index->shape().rank() - 1, + update_dim = lhs_scatter_update->shape().rank() - 1; + d >= 0; --d) { + if (d == lhs_dnums.index_vector_dim()) { + continue; + } + while ( + absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) { + --update_dim; + } + if (lhs_scatter_index->shape().dimensions(d) == + rhs_scatter_index->shape().dimensions(d)) { + first_index_dim = d; + first_update_dim = update_dim--; + continue; + } + // More than one dimension of unequal size was found, bail out. + if (index_concat_dimension) { + return Status::OK(); + } + index_concat_dimension = d; + update_concat_dimension = update_dim--; + } + if (!index_concat_dimension) { + index_concat_dimension = first_index_dim; + update_concat_dimension = first_update_dim; + } + + // A scalar scatter will require additional reshapes of the index and + // update. 
+ if (*index_concat_dimension == lhs_scatter_index->shape().rank()) { + return Status::OK(); + } + const bool update_concat_is_cheap = + ShapeUtil::ElementsIn(rhs_scatter_update->shape()) + + ShapeUtil::ElementsIn(lhs_scatter_update->shape()) < + ShapeUtil::ElementsIn(lhs->shape()); + if (!update_concat_is_cheap) { + return Status::OK(); + } + const bool same_dimension_numbers = + lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() && + absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(), + rhs_dnums.scatter_dims_to_operand_dims()) && + absl::c_equal(lhs_dnums.inserted_window_dims(), + rhs_dnums.inserted_window_dims()) && + absl::c_equal(lhs_dnums.update_window_dims(), + rhs_dnums.update_window_dims()); + const bool index_concat_is_safe = + !lhs->unique_indices() && !rhs->unique_indices() && + !DynCast(lhs)->indices_are_sorted() && + !DynCast(rhs)->indices_are_sorted(); + + Shape lhs_update_window = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return absl::c_linear_search(lhs_dnums.update_window_dims(), dim); + }, + lhs_scatter_update->shape()); + Shape rhs_update_window = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return absl::c_linear_search(rhs_dnums.update_window_dims(), dim); + }, + rhs_scatter_update->shape()); + // Concatenate the indices and updates + if (index_concat_is_safe && same_dimension_numbers && + index_concat_dimension && + ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { + TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, + rhs_scatter_operand)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_index, + MakeConcatHlo({lhs_scatter_index, rhs_scatter_index}, + *index_concat_dimension)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_update, + MakeConcatHlo({lhs_scatter_update, rhs_scatter_update}, + *update_concat_dimension)); + return ReplaceWithNewInstruction( + add, HloInstruction::CreateScatter( + add->shape(), new_operand, new_index, new_update, + 
lhs->to_apply(), lhs_dnums, false, false)); + } + TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, + rhs_scatter_operand)); + TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); + TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs)); + return ReplaceInstruction(add, lhs); + } else if (rhs_scatter) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand)); + TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); + return ReplaceInstruction(add, rhs); + } else if (lhs_scatter) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs)); + TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand)); + return ReplaceInstruction(add, lhs); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 255edf78345..9bbf692b4f9 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6217,5 +6217,214 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) { GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } +TEST_F(AlgebraicSimplifierTest, ScatterAddCombined) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,4] broadcast(z), dimensions={} + shared = f32[100,4] parameter(0) + index0 = s32[20] parameter(1) + index1 = s32[10] parameter(2) + update0 = f32[20,4] parameter(3) + update1 = f32[10,4] parameter(4) + scatter.0 = f32[100,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + scatter.1 = f32[100,4] scatter(init, index1, update1), + 
to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + add.0 = f32[100,4] add(shared, scatter.0) + ROOT add.1 = f32[100,4] add(add.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Optimize Add with 0 + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Scatter(m::Parameter(0), + m::Concatenate(m::Parameter(1), m::Parameter(2)), + m::Concatenate(m::Parameter(3), m::Parameter(4))))); +} + +TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedSwapped) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,4] broadcast(z), dimensions={} + shared = f32[100,4] parameter(0) + index0 = s32[20] parameter(1) + index1 = s32[10] parameter(2) + update0 = f32[20,4] parameter(3) + update1 = f32[10,4] parameter(4) + scatter.0 = f32[100,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + scatter.1 = f32[100,4] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + add.0 = f32[100,4] add(shared, scatter.0) + ROOT add.1 = f32[100,4] add(scatter.1, add.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Optimize Add with 0 + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + 
GmockMatch(m::Scatter(m::Parameter(0), + m::Concatenate(m::Parameter(2), m::Parameter(1)), + m::Concatenate(m::Parameter(4), m::Parameter(3))))); +} + +TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,4] broadcast(z), dimensions={} + shared = f32[100,4] parameter(0) + index0 = s32[1,4,5] parameter(1) + index1 = s32[1,2,5] parameter(2) + update0 = f32[4,4,5] parameter(3) + update1 = f32[2,4,5] parameter(4) + scatter.0 = f32[100,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 + scatter.1 = f32[100,4] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 + ROOT add.1 = f32[100,4] add(scatter.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Simplify Add + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Scatter(m::Broadcast(), + m::Concatenate(m::Parameter(1), m::Parameter(2)), + m::Concatenate(m::Parameter(3), m::Parameter(4))))); +} + +TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,4] broadcast(z), dimensions={} + shared = f32[100,4] parameter(0) + index0 = s32[4,3,1] parameter(1) + index1 = s32[4,5,1] parameter(2) + update0 = f32[4,4,3] parameter(3) + update1 = f32[4,4,5] parameter(4) + 
scatter.0 = f32[100,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={0}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 + scatter.1 = f32[100,4] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={0}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 + ROOT add.1 = f32[100,4] add(scatter.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + // Simplify Add + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Scatter(m::Broadcast(), + m::Concatenate(m::Parameter(1), m::Parameter(2)), + m::Concatenate(m::Parameter(3), m::Parameter(4))))); +} + +TEST_F(AlgebraicSimplifierTest, ScalarScatter) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,4,20] broadcast(z), dimensions={} + shared = f32[100,4,20] parameter(0) + index0 = s32[1] parameter(1) + index1 = s32[1] parameter(2) + update0 = f32[4,20] parameter(3) + update1 = f32[4,20] parameter(4) + scatter.0 = f32[100,4,20] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={0, 1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 + scatter.1 = f32[100,4,20] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={0, 1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 + ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} } // 
namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3e9daa96150..6cbd33053fc 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -99,6 +99,9 @@ std::vector ColorInterferenceGraph( bool HloBufferIsReadOnly(const HloBuffer& buffer) { for (const HloValue* value : buffer.values()) { const HloInstruction* instruction = value->instruction(); + if (instruction->opcode() == HloOpcode::kConstant) { + return true; + } const HloModule* module = instruction->parent()->parent(); const bool is_entry_parameter = instruction->opcode() == HloOpcode::kParameter && diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index ab959cb0087..323bf44dcd3 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -225,10 +225,12 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { const int64 kernel_output_feature_dimension = dim_numbers.kernel_output_feature_dimension(); + const int64 input_batch = + activation->shape().dimensions(input_batch_dimension); const int64 output_feature = filter->shape().dimensions(kernel_output_feature_dimension); - if (output_feature != batch_group_count) { + if (output_feature != batch_group_count || input_batch != batch_group_count) { // Insert a spatial dimension to the activation before the input batch // dimension to represent the batch group. 
std::vector input_sizes(activation->shape().dimensions().begin(), diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc index fea37130c6d..143e071dc3c 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -119,5 +119,28 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16 EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow); } +TEST_F(ConvolutionGroupConverterTest, + ConvertBatchGroupCountNotEqualToInputBatchDim) { + string hlo_string = R"(HloModule m + ENTRY main { + %input = f32[1,1,1,4] parameter(0) + %filter = f32[1,1,1,2] parameter(1) + ROOT %convolution = f32[1,1,2,2] convolution(%input,%filter), + window={size=1x1}, dim_labels=f01b_i01o->01fb, batch_group_count=2 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto cost_model = [](HloInstruction* conv) { return false; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ + true); + // Make sure that batch group count is rewritten even if + // batch_group_count == output_feature but not input_batch + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 0b176031e8d..7c6cab6b738 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -209,38 +209,35 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim, // [[a,b,P] // [c,d,P]] // -// The way we do this is by a 6-steps double-sorting algorithm: 
+// The way we do this is by a 5-steps cumsum-gather algorithm: // // 1.First we use the output shape to generate a binary 0-1 masking, which masks // out the padded area of the output: -// [[0,0,1] -// [0,0,1]] +// [[1,1,0] +// [1,1,0]] // // 2.Then we do an inverse reshape to reshape it from output shape back to input // shape [2,3]->[6]: -// [0,0,1,0,0,1] +// [1,1,0,1,1,0] // -// 3.We then generate an iota mask using the input shape: -// [0,1,2,3,4,5] +// 3.We then do a cumsum with the mask: +// [1,2,2,3,4,4] and subtract it with 1: +// [0,1,1,2,3,3] // -// 4.Stable sort the iota mask using the binary mask as key: -// key [0,0,1,0,0,1] -// value[0,1,2,3,4,5] -// | Sort by key +// 4.Use the the result of cumsum as gather indicies to rearrange the original +// data. Feed the original input [a,b,c,d,P,P] and indices into gather. +// +// operand [a,b,c,d,P,P], indices [0,1,1,2,3,3] +// | | +// Gather-----------------+ +// | // v -// key [0,0,0,0,1,1] -// value[0,1,3,4,2,5] +// value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as padding value +// doesn't matter. 
// -// 5.Sort the original input [a,b,c,d,P,P] using the sorted iota mask: -// key [0,1,3,4,2,5] -// value[a,b,c,d,P,P] -// | Sort by key -// v -// key [0,1,2,3,4,5] -// value[a,b,P,c,d,P] // -// 6.Feed the sorted input to original reshape[6]->[2,3], we can get the correct -// reshape: +// 5.Feed the sorted input to original reshape[6]->[2,3], we can now get the +// correct result: // [[a,b,P] // [c,d,P]] // @@ -248,27 +245,37 @@ Status RewriteDynamicReshapeSplitInput( HloInstruction* reshape, int64 input_dim, absl::Span output_dims, DynamicDimensionInference* dynamic_dimension_inference) { + VLOG(1) << "Reshaping input dim " << input_dim << "to " + << VectorString(output_dims); const Shape operand_shape = reshape->operand(0)->shape(); TF_RET_CHECK(output_dims.size() > 1); HloComputation* comp = reshape->parent(); const Shape mask_input_shape = - ShapeUtil::ChangeElementType(operand_shape, xla::S32); + ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)}); + + std::vector reshaped_dims; + for (int64 output_dim : output_dims) { + reshaped_dims.push_back(reshape->shape().dimensions(output_dim)); + } + const Shape mask_reshaped_shape = - ShapeUtil::ChangeElementType(reshape->shape(), xla::S32); + ShapeUtil::MakeShape(xla::S32, reshaped_dims); HloInstruction* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); HloInstruction* one = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::One(S32))); // Step 1 -- generate binary mask. - // Mask starts with all zero, each dynamic dimension sets one dimension of the - // mask to partially one. + // Mask starts with all one, each dynamic dimension sets that dimension of the + // mask to partially zero in the end. HloInstruction* binary_mask = comp->AddInstruction( - HloInstruction::CreateBroadcast(mask_reshaped_shape, zero, {})); + HloInstruction::CreateBroadcast(mask_reshaped_shape, one, {})); bool need_rewrite = false; + // Pad the effective dimension with 1. 
+ // // Index starts from 1 since there is no need to rewrite a major output // dimension. for (int64 i = 1; i < output_dims.size(); ++i) { @@ -278,10 +285,10 @@ Status RewriteDynamicReshapeSplitInput( if (dynamic_size == nullptr) { continue; } - // If there is dynamic dimension in the output, need rewrite the input. + // If there is dynamic dimension in the output, need to rewrite the input. need_rewrite = true; - binary_mask = PadWithScalar(binary_mask, output_dim, dynamic_size, one); + binary_mask = PadWithScalar(binary_mask, i, dynamic_size, zero); } if (!need_rewrite) { return Status::OK(); @@ -292,90 +299,77 @@ Status RewriteDynamicReshapeSplitInput( HloInstruction* input_shape_binary_mask = comp->AddInstruction( HloInstruction::CreateReshape(mask_input_shape, binary_mask)); - // Step 3. Generate iota mask. - HloInstruction* iota_mask = comp->AddInstruction( - HloInstruction::CreateIota(mask_input_shape, input_dim)); - - // Step 4. Sort iota. - // Use binary mark to sort iota mask, then use iota mask to reshape input. - HloComputation::Builder comp_builder("compare_binary_iota"); + // Step 3. Do a cumsum on the binary mask. 
+ auto embedded_builder = HloComputation::Builder("add"); { - HloInstruction* lhs_key = - comp_builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "lhs_key_binary")); - HloInstruction* rhs_key = - comp_builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {}), "rhs_key_binary")); - - // Values for lhs and rhs - comp_builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(S32, {}), "lhs_iota")); - comp_builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(S32, {}), "rhs_iota")); - comp_builder.AddInstruction( - HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key, - rhs_key, ComparisonDirection::kLt)); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); } - HloComputation* compare_binary_iota = - comp->parent()->AddEmbeddedComputation(comp_builder.Build()); + HloComputation* add = + reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build()); + Window cumsum_window; + // First dimension is unchanged. 
+ WindowDimension* dim = cumsum_window.add_dimensions(); + dim->set_size(operand_shape.dimensions(input_dim)); + dim->set_stride(1); + dim->set_padding_low(operand_shape.dimensions(input_dim) - 1); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + HloInstruction* cumsum = + comp->AddInstruction(HloInstruction::CreateReduceWindow( + mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add)); - HloInstruction* sorted_binary_iota = - comp->AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({mask_input_shape, mask_input_shape}), - input_dim, {input_shape_binary_mask, iota_mask}, compare_binary_iota, - /*is_stable=*/true)); - HloInstruction* sorted_iota_mask = - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - mask_input_shape, sorted_binary_iota, 1)); + HloInstruction* broadcast_ones = comp->AddInstruction( + HloInstruction::CreateBroadcast(mask_input_shape, one, {})); + cumsum = comp->AddInstruction(HloInstruction::CreateBinary( + mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones)); - // Step 5. Sort original input using iota mask as key. 
- HloComputation::Builder comp_builder_iota("compare_binary_iota"); - { - HloInstruction* lhs_key = - comp_builder_iota.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "lhs_key_iota")); - HloInstruction* rhs_key = - comp_builder_iota.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {}), "rhs_key_iota")); - - // Values for lhs and rhs - comp_builder_iota.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(operand_shape.element_type(), {}), - "lhs_value")); - comp_builder_iota.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(operand_shape.element_type(), {}), - "rhs_value")); - comp_builder_iota.AddInstruction( - HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key, - rhs_key, ComparisonDirection::kLt)); + GatherDimensionNumbers gather_dim_numbers; + // We use gather to rearrange the input dim dimension. However the current + // semantic of gather doesn't allow us to collapse dimension in this case so + // we keep it, which make the gather from shape [..., input_dim, ...] to + // [..., 1, input_dim, ...] + for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { + // Offset dim is every dimension including newly added size 1 dim, except + // for input_dim, which acts as a batch_dim. + if (i != input_dim) { + gather_dim_numbers.add_offset_dims(i); + } } + // The dimension to rewrite is the index dim. + gather_dim_numbers.add_start_index_map(input_dim); + gather_dim_numbers.set_index_vector_dim(1); + gather_dim_numbers.add_collapsed_slice_dims(input_dim); - HloComputation* compare_iota_value = - comp->parent()->AddEmbeddedComputation(comp_builder_iota.Build()); + // Step 4. Gather. - // Temporarily removes dynamic dimension before entering sort -- we want the - // sort to ignore dynamic dimension. + // Temporarily removes dynamic dimension before entering gather -- we want the + // gather to ignore dynamic dimension. 
HloInstruction* operand_static_dim_size = comp->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR0(operand_shape.dimensions(input_dim)))); - HloInstruction* operand_static = comp->AddInstruction(HloInstruction::CreateSetDimensionSize( operand_shape, reshape->mutable_operand(0), operand_static_dim_size, input_dim)); - HloInstruction* sorted_iota_value = - comp->AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({mask_input_shape, operand_shape}), - input_dim, {sorted_iota_mask, operand_static}, compare_iota_value, - /*is_stable=*/true)); - // Step 6: Feed sorted input to original reshape. - HloInstruction* sorted_operand = - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - operand_shape, sorted_iota_value, 1)); + std::vector slice_sizes(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + slice_sizes[input_dim] = 1; + HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather( + ShapeUtil::MakeShape(operand_shape.element_type(), + operand_shape.dimensions()), + operand_static, cumsum, gather_dim_numbers, slice_sizes, true)); - TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, sorted_operand)); + // Step 6: Feed gather input to original reshape. 
+ + TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather)); HloInstruction* reshape_dynamic = reshape; diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index e669bc4dbe2..d3b68266b4f 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -682,6 +682,98 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMajor) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[2, 6] parameter(0) + const = s32[] constant(4) + param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1} + // Third dimension is dynamic. + reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2 + init = s32[] constant(0) + ROOT reduce = s32[2, 2] reduce(reshaped, init), + dimensions={2}, + to_apply=update_s32 +} +)"; + + // The third dimension has upper bound of 5, dynamic dimension is 3. 
+ Literal operand = + LiteralUtil::CreateR2({{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + + // After padding and reshape we have + // + // [[[0, 1, P], + // [2, 3, P]], + // [[6, 7, P], + // [8, 9, P]]] + // Reducing on the third dimension gives us + // [0+1, 2+3] + // [6+7, 8+9] + // + Literal expected = LiteralUtil::CreateR2({{1, 5}, {13, 17}}); + + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMinor) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[6, 2] parameter(0) + const = s32[] constant(4) + param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0} + // Second dimension is dynamic. + reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1 + init = s32[] constant(0) + ROOT reduce = s32[2, 2] reduce(reshaped, init), + dimensions={1}, + to_apply=update_s32 +} +)"; + + // The third dimension has upper bound of 5, dynamic dimension is 3. 
+ Literal operand = LiteralUtil::CreateR2( + {{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + + // After padding and reshape we have + // + // [[[0, 1], + // [2, 3] + // [P, P]], + // [[4, 5], + // [6, 7], + // [P, P]]] + // Reducing on the second dimension gives us + // [0+2, 1+3] + // [4+6, 5+7] + // + Literal expected = LiteralUtil::CreateR2({{2, 4}, {10, 12}}); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) { const string hlo_text = R"( HloModule TensorFlowScatterV1 diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f73a15e1533..767c34b3a99 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -152,9 +152,11 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); - auto cost_model = [](HloInstruction*) { - // We need a cost model for GPUs. Currently, do nothing. 
- return false; + auto cost_model = [](HloInstruction* conv) { + auto operand = conv->operand(0); + return operand->shape().dimensions(conv->convolution_dimension_numbers() + .input_batch_dimension()) == + conv->batch_group_count(); }; pipeline.AddPass(cost_model); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 6a0b9e5dfb8..4894e566393 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -1150,6 +1150,9 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( HloInstruction* user, const ShapeIndex& user_index) const { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (operand->opcode() == HloOpcode::kConstant) { + return false; + } const Shape& operand_subshape = ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 8b0f2db13bb..ec048bef9e8 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -231,6 +231,7 @@ HLO_MATCHER(Fft); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Gather); +HLO_MATCHER(GetDimensionSize); HLO_MATCHER(Infeed); HLO_MATCHER(Iota); HLO_MATCHER(IsFinite); @@ -261,6 +262,7 @@ HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); HLO_MATCHER(SendDone); +HLO_MATCHER(SetDimensionSize); HLO_MATCHER(ShiftLeft); HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(ShiftRightLogical); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc index bd64c18680c..7445ab5221a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc @@ 
-21,7 +21,7 @@ namespace mlir { namespace { struct InjectErrorsForTestingPass - : public FunctionPass { + : public PassWrapper { void runOnFunction() override { getFunction().getBody().walk([&](Operation *op) { op->emitError() << "failed for testing: " << op->getName(); @@ -31,7 +31,7 @@ struct InjectErrorsForTestingPass } // namespace -std::unique_ptr> createInjectErrorsForTestingPass() { +std::unique_ptr> createInjectErrorsForTestingPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h index dd19fbe35cb..9f0612c8868 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h @@ -22,7 +22,7 @@ namespace mlir { // Returns a function pass that emits errors from all operations inside the // function. -std::unique_ptr> createInjectErrorsForTestingPass(); +std::unique_ptr> createInjectErrorsForTestingPass(); } // namespace mlir diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 7482c9df1fe..5b684c075bb 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -63,7 +63,8 @@ using ::mlir::xla_lhlo::FusionOp; // This is needed, as these ops are not closed from above and hence nested pass // managers can not be applied. struct NestedHloRegionsConverter - : public mlir::FunctionPass { + : public mlir::PassWrapper { void runOnFunction() override { auto& ctx = getContext(); mlir::OwningRewritePatternList patterns; @@ -83,7 +84,8 @@ struct NestedHloRegionsConverter }; // Replaces a FusionOp by the operations contained in its region. 
-struct FusionOpRemover : public mlir::FunctionPass { +struct FusionOpRemover + : public mlir::PassWrapper { void runOnFunction() override { getFunction().walk([&](FusionOp op) { mlir::OpBuilder builder(op); @@ -105,7 +107,7 @@ struct FusionOpRemover : public mlir::FunctionPass { // Rewrite the single-trip loops we get out of linalg into just their bodies. // TODO(herhut): Make this a general pattern. struct SingleTripLoopRemoval - : public mlir::FunctionPass { + : public mlir::PassWrapper { void runOnFunction() override { auto getConstantValue = [](mlir::Value value) -> llvm::Optional { auto definingOp = value.getDefiningOp(); @@ -142,7 +144,8 @@ struct SingleTripLoopRemoval // Simple pass that replaces a load that immediately follows a store to the // same address with the stored value. This needs generalization. -struct StoreForwardingPass : mlir::FunctionPass { +struct StoreForwardingPass + : mlir::PassWrapper { void runOnFunction() override { llvm::DenseMap memrefToAllocOp; @@ -208,7 +211,8 @@ struct StoreForwardingPass : mlir::FunctionPass { // never read from or that are read but the read value is not used. // Needs an analysis that proves that loads and stores are side-effect free // (in bounds, no aliasing, etc.). -struct DeadTempBufferRemoval : mlir::FunctionPass { +struct DeadTempBufferRemoval + : mlir::PassWrapper { bool operationConsideredDead(mlir::Operation* op) { for (auto result : op->getResults()) { if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { @@ -325,7 +329,8 @@ namespace { /// A pass that does the final lowering to NVVM. It collects all the patterns /// that are currently required, currently mixing std, linalg and gpu. 
class LowerToNVVMPass - : public ::mlir::OperationPass { + : public ::mlir::PassWrapper< + LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> { public: void runOnOperation() override { ::mlir::gpu::GPUModuleOp m = getOperation(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo index 9abc2dad0aa..e10b8e72f34 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo @@ -5,5 +5,5 @@ ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "xla_lhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo index 9af0de99d42..5eec5d98b22 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo @@ -6,6 +6,6 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.exp"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "xla_lhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo index caead37c995..e0b42c4da12 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo @@ -5,5 +5,5 @@ ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "xla_lhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], 
[[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4f9e3a4d083..2d692183338 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1348,6 +1348,7 @@ xla_test( srcs = ["dynamic_ops_test.cc"], deps = [ ":client_library_test_base", + ":hlo_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 555dfc48d9e..0974d37779e 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -731,6 +732,24 @@ XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) { RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); } +// This test that buffer assignment does not alias constants with the output of +// dynamic update slice. 
+XLA_TEST_F(HloTestBase, AddOfDUS) { + const char* hlo_string = R"( + HloModule m + test { + o = s32[6] constant({2,3,4,5,6,7}) + i = s32[] parameter(0) + u = s32[2] parameter(1) + dus = s32[6] dynamic-update-slice(o,u,i) + a = s32[6] add(dus, dus) + j = s32[] parameter(2) + ROOT ds = s32[2] dynamic-slice(a, j), dynamic_slice_sizes={2} + } + )"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + void BM_DynamicSlice(int num_iters) { tensorflow::testing::StopTiming(); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b309a1f2e24..c36664c70fc 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -238,6 +238,7 @@ FRAMEWORK_PROTO_SRCS = [ PROFILER_PROTO_SRCS = [ "//tensorflow/core/profiler/protobuf:xplane.proto", + "//tensorflow/core/profiler:profiler_options.proto", ] ERROR_CODES_PROTO_SRCS = [ @@ -2162,6 +2163,7 @@ tf_proto_library( "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", "//tensorflow/core/profiler/protobuf:xplane_proto", + "//tensorflow/core/profiler:profiler_options_proto", "//tensorflow/core/util:protos_all", "//tensorflow/core/util:test_log_proto_impl", ], diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index b64e14212c9..bf5b1cfaba3 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -429,6 +429,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":shared_counter", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 8d37b1889bf..de2dc28c979 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -142,8 +142,8 @@ class CollectiveAdapterImpl : public CollectiveAdapter { Tensor 
TempChunk(int i) const override { AllocationAttributes empty; - auto op_annotation = - ScopedMemoryDebugAnnotation("CollectiveAdapterImpl::TempChunk", 0); + ScopedMemoryDebugAnnotation op_annotation( + "CollectiveAdapterImpl::TempChunk"); return Tensor(allocator_, dt_, {ChunkElts(i)}, empty); } diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 5ed9a3e67b7..f65cfcf8f00 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -29,6 +29,7 @@ limitations under the License. #ifdef TENSORFLOW_MEM_DEBUG #include "tensorflow/core/platform/stacktrace.h" #endif +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/bfc_memory_map.pb.h" @@ -427,11 +428,13 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment, // Dump the memory log for analysis. 
MaybeWriteMemoryMap(); if (dump_log_on_failure) { - LOG(WARNING) << "Allocator (" << Name() << ") ran out of memory trying " - << "to allocate " << strings::HumanReadableNumBytes(num_bytes) - << " (rounded to " << rounded_bytes << ")" - << "requested by op " << pending_op_name - << "\nCurrent allocation summary follows."; + LOG(WARNING) + << "Allocator (" << Name() << ") ran out of memory trying " + << "to allocate " << strings::HumanReadableNumBytes(num_bytes) + << " (rounded to " << rounded_bytes << ")" + << "requested by op " + << ScopedMemoryDebugAnnotation::CurrentAnnotation().pending_op_name + << "\nCurrent allocation summary follows."; DumpMemoryLog(rounded_bytes); LOG(WARNING) << RenderOccupancy(); } @@ -453,6 +456,11 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name, memory_limit_ - stats.bytes_reserved - stats.bytes_in_use; BFCAllocator::Chunk* chunk = ChunkFromHandle(region_manager_.get_handle(chunk_ptr)); + const auto& annotation = + ScopedMemoryDebugAnnotation::CurrentAnnotation(); + std::string tensor_shape = annotation.pending_shape + ? 
annotation.pending_shape->DebugString() + : ""; return absl::StrCat(traceme_name, "#allocator_name=", name_, ",bytes_reserved=", stats.bytes_reserved, @@ -462,8 +470,11 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name, ",requested_bytes=", chunk->requested_size, ",allocation_bytes=", chunk->size, ",addr=", reinterpret_cast(chunk_ptr), - ",tf_op=", pending_op_name, ",id=", pending_step_id, - "#"); + ",tf_op=", annotation.pending_op_name, + ",id=", annotation.pending_step_id, + ",region_type=", annotation.pending_region_type, + ",data_type=", annotation.pending_data_type, + ",shape=", tensor_shape, "#"); }, traceme_level); } @@ -516,17 +527,20 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, #ifdef TENSORFLOW_MEM_DEBUG if (ShouldRecordOpName()) { - if (pending_op_name != nullptr) { - chunk->op_name = pending_op_name; + const auto& annotation = + ScopedMemoryDebugAnnotation::CurrentAnnotation(); + if (annotation.pending_op_name != nullptr) { + chunk->op_name = annotation.pending_op_name; } else { LOG(INFO) << "missing pending_op_name for " << Name() << " reading addr " - << static_cast(&pending_op_name) << "\n" + << static_cast(&annotation.pending_op_name) + << "\n" << CurrentStackTrace(); chunk->op_name = nullptr; } chunk->action_count = ++action_counter_; - chunk->step_id = pending_step_id; + chunk->step_id = annotation.pending_step_id; int slot = chunk->action_count % MEM_DEBUG_SIZE_HISTORY_SIZE; size_history_[slot] = stats_.bytes_in_use; } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 877d8072008..1670345efd5 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -125,6 +125,11 @@ class CustomDevice { int* num_retvals) = 0; }; +// Custom devices do many of the same things as physical Devices, but have a +// much more restricted interface. 
We pass around ambiguous pointers since +// TensorHandles may be placed either on custom or physical devices. +using VariantDevice = absl::variant; + class EagerContext : public AbstractContextInterface, public core::RefCounted { public: static const uint64 kInvalidContextId = 0; diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index 793513c5c5f..b3f38c2843e 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -50,8 +50,8 @@ class CopyToDeviceNode : public EagerNode { Status Run() override { tensorflow::Tensor tensor; - auto op_annotation = ScopedMemoryDebugAnnotation( - pending_op_name ? pending_op_name : "eager::CopyToDeviceNode"); + ScopedMemoryDebugAnnotation op_annotation( + "eager::CopyToDeviceNode", "dynamic", tensor.dtype(), &tensor.shape()); TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); if (!async_ && mirror_) { return dst_->AddLocalMirror(std::move(tensor), dstd_); diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index de7e7475a1c..cfb188bdd77 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -21,8 +21,7 @@ limitations under the License. 
namespace { -bool IsCPU( - absl::variant variant) { +bool IsCPU(tensorflow::VariantDevice variant) { if (VariantDeviceIsCustom(variant)) { return false; } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 3804f5164d4..d3a31278326 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -37,7 +37,7 @@ void EagerOperation::Clear() { } const string& EagerOperation::DeviceName() const { - absl::variant variant_device = + VariantDevice variant_device = (Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device(); return absl::visit([](auto* d) -> const string& { return d->name(); }, variant_device); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 550881c571b..d1128977ace 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -119,9 +119,7 @@ class EagerOperation : public AbstractOperationInterface { // Like TensorHandles, EagerOperations may be placed either on a virtual // CustomDevice or on a physical Device. 
- absl::variant Device() const { - return device_; - } + VariantDevice Device() const { return device_; } void SetDevice(tensorflow::Device* device) { device_ = device; @@ -185,7 +183,7 @@ class EagerOperation : public AbstractOperationInterface { AttrBuilder attrs_; const AttrTypeMap* attr_types_; absl::InlinedVector inputs_; - absl::variant device_; + VariantDevice device_; string raw_device_name_; string device_name_; DeviceNameUtils::ParsedName device_parsed_name_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index f1c90119bda..8c602b0f498 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -80,8 +80,7 @@ const string& DeviceNameOrUnspecified(Device* device) { return (device == nullptr) ? *unspecified_string : device->name(); } -const string& DeviceNameOrUnspecified( - absl::variant device) { +const string& DeviceNameOrUnspecified(VariantDevice device) { if (VariantDeviceIsCustom(device)) { return absl::get(device)->name(); } else { @@ -374,7 +373,10 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, // running without an explicitly requested device. Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, int* num_retvals) { - auto op_annotation = ScopedMemoryDebugAnnotation(op->op_name()); + ScopedMemoryDebugAnnotation op_annotation( + op->op_name(), op->remote_func_params().has_value() + ? 
op->remote_func_params().value().step_id.value_or(0) + : 0); profiler::TraceMe activity( [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); }, profiler::TraceMeLevel::kInfo); diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index 5ced006fb9e..f2528081877 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -53,8 +53,7 @@ Status ExecuteNodeArgs::Init( serialize_remote_handle_ = [ctx, &op_inputs](const int i, eager::RemoteTensorHandle* handle) -> Status { - absl::variant variant_device = - op_inputs[i]->device(); + VariantDevice variant_device = op_inputs[i]->device(); if (VariantDeviceIsCustom(variant_device)) { return errors::Internal( "Custom devices and remote execution are currently not supported " diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 2cbb978b5ee..858d0a338ae 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -326,8 +326,7 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { return mirror.TensorValue(t); } -TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU( - const EagerContext& ctx) const { +VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const { if (VariantDeviceIsCustom(device_)) { return device_; } else { @@ -788,16 +787,15 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx, return status; } -bool VariantDeviceIsCustom( - absl::variant variant_device) { +bool VariantDeviceIsCustom(VariantDevice variant_device) { return variant_device.index() != 0; } -string VariantDeviceName(absl::variant device) { +string VariantDeviceName(VariantDevice device) { return absl::visit([](auto* device) { return device->name(); }, device); } -string VariantDeviceDebugString(absl::variant 
device) { +string VariantDeviceDebugString(VariantDevice device) { if (device == kVariantDeviceNull) { return "[]"; } else if (VariantDeviceIsCustom(device)) { diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 9309b4fcccd..0b39161af73 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -55,11 +55,6 @@ class EagerContext; // (unrelated to python TensorHandle). class TensorHandle : public AbstractTensorHandleInterface, public core::RefCounted { - // Custom devices do many of the same things as physical Devices, but have a - // much more restricted interface. We pass around ambiguous pointers since - // TensorHandles may be placed either on custom or physical devices. - using VariantDevice = absl::variant; - // TensorHandle for dtype != DT_RESOURCE TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, Device* resource_device, EagerContext* ctx); @@ -291,18 +286,17 @@ class TensorHandle : public AbstractTensorHandleInterface, }; // Checks whether a VariantDevice contains a custom device. -bool VariantDeviceIsCustom(absl::variant device); +bool VariantDeviceIsCustom(VariantDevice device); // Wraps device->name() or CustomDevice->name(). -string VariantDeviceName(absl::variant device); +string VariantDeviceName(VariantDevice device); // Wraps device->DebugString() or CustomDevice->name(). -string VariantDeviceDebugString(absl::variant device); +string VariantDeviceDebugString(VariantDevice device); // Indicates either HostCPU or an unset physical device. We never set a null // CustomDevice*. -const absl::variant kVariantDeviceNull = - static_cast(nullptr); +const VariantDevice kVariantDeviceNull = static_cast(nullptr); // Returns the device backing the resource. Else, returns nullptr. 
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index f1177e8cba4..e744cdf1c4d 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -375,10 +375,6 @@ class ExecutorState { mutex mu_; Status status_ TF_GUARDED_BY(mu_); - - // A flag that is set on error after the propagator state has been - // dumped for diagnostic purposes. - bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; }; template @@ -932,11 +928,6 @@ Status ExecutorState::ProcessOutputs( // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { LOG(WARNING) << this << " Compute status: " << s; - mutex_lock l(mu_); - if (!dumped_on_error_) { - propagator_.DumpState(); - dumped_on_error_ = true; - } } if (s.code() == error::RESOURCE_EXHAUSTED) { if (stats_collector_) { @@ -1189,6 +1180,12 @@ void ExecutorState::Finish() { CHECK(done_cb != nullptr); Device* device = immutable_state_.params().device; + if (vlog_ && !status.ok() && VLOG_IS_ON(1)) { + // Logs verbose information about the current state of active and pending + // nodes in the propagator. + propagator_.DumpState(); + } + // There are several potential race conditions below. To name a few: // 1. 
Even if the device's status is OK at the precise moment when // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus() diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index da6a2eadea2..00e237ad253 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -340,7 +340,7 @@ Status BaseGPUDevice::InitScratchBuffers() { if (!scratch_) { DCHECK(stream_); size_t scratch_buffer_size = Eigen::kGpuScratchSize + sizeof(unsigned int); - auto op_annotation = ScopedMemoryDebugAnnotation("ScratchBuffer"); + ScopedMemoryDebugAnnotation op_annotation("ScratchBuffer"); void* scratch_buffer = gpu_allocator_->AllocateRaw( Allocator::kAllocatorAlignment, scratch_buffer_size); if (scratch_buffer == nullptr) { @@ -498,8 +498,8 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { } } ScopedActivateExecutorContext scoped_activation{stream->parent()}; - auto op_annotation = ScopedMemoryDebugAnnotation( - op_kernel->name_view().data(), context->step_id()); + ScopedMemoryDebugAnnotation op_annotation(op_kernel->name_view().data(), + context->step_id()); op_kernel->Compute(context); if (context->status().ok()) { if (sync_every_op_) { @@ -612,8 +612,6 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU( Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { - auto op_annotation = ScopedMemoryDebugAnnotation( - (pending_op_name != nullptr ? 
pending_op_name : "MakeTensorFromProto")); AllocatorAttributes attr; attr.set_on_host(true); attr.set_gpu_compatible(true); @@ -624,6 +622,8 @@ Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, tensor_proto.DebugString()); } + ScopedMemoryDebugAnnotation op_annotation("MakeTensorFromProto", "dynamic", + parsed.dtype(), &parsed.shape()); if (parsed.dtype() == DT_VARIANT) { const Variant* from = parsed.flat().data(); int numa_node = attributes().locality().numa_node(); diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index df4cd4bffbb..e5097923f14 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -409,8 +409,9 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank, const Tensor* src_tensor, const StatusCallback& done) { - auto op_annotation = ScopedMemoryDebugAnnotation( - col_ctx_->op_ctx->op_kernel().name_view().data()); + ScopedMemoryDebugAnnotation op_annotation( + col_ctx_->op_ctx->op_kernel().name_view().data(), col_ctx_->step_id, + "dynamic", src_tensor->dtype(), &src_tensor->shape()); string send_buf_key = BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); int dst_idx = diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 3df124e934b..cf9f609bbd0 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -74,7 +74,8 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr, return; } - auto op_annotation = ScopedMemoryDebugAnnotation("SameWorkerRecvDone"); + ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", 0, "dynamic", + in.dtype(), &in.shape()); AllocatorAttributes attr = recv_args.alloc_attrs; attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || 
recv_args.alloc_attrs.gpu_compatible()); @@ -112,7 +113,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, RendezvousInterface::DoneCallback done) { VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey(); - auto op_annotation = ScopedMemoryDebugAnnotation("RecvAsync"); + ScopedMemoryDebugAnnotation op_annotation("RecvAsync"); // Recv the tensor from local_. local->RecvAsync( parsed, recv_args, diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index d921e9c2cf1..ff67ee6f0b2 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -272,9 +272,8 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( return; } - // Note that it would be nice to cache the step_id here, but it's not - // available. - auto op_annotation = ScopedMemoryDebugAnnotation("SameWorkerRecvDone", 0); + ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", step_id_, + "dynamic", in.dtype(), &in.shape()); AllocatorAttributes attr = recv_args.alloc_attrs; attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || recv_args.alloc_attrs.gpu_compatible()); @@ -323,7 +322,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: " << parsed.FullKey() << ")."; - auto op_annotation = ScopedMemoryDebugAnnotation("RecvAsync", 0); + ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_); // Are src and dst in the same worker? if (IsSameWorker(parsed.src, parsed.dst)) { // Recv the tensor from local_. 
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index 5afd679dc9f..73fb9d9cf64 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -129,9 +129,10 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( } AllocatorAttributes cpu_attr; cpu_attr.set_gpu_compatible(true); - auto op_annotation = ScopedMemoryDebugAnnotation( + ScopedMemoryDebugAnnotation op_annotation( "CollectiveRemoteAccessDistributed::RecvFromPeer" - "::recv_buf_callback"); + "::recv_buf_callback", + step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape()); Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), to_tensor->dtype(), to_tensor->shape()); PopulateTensorFromExtra(extra, cpu_tensor); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index ef3d42de037..c120a28032c 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -76,7 +76,7 @@ Status RemoteMgr::GetMirroredResourceShape( Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, int64* op_id, int32* output_num) { // TODO(allenl): Consider supporting remote handles on custom devices. 
- absl::variant device = handle->device(); + VariantDevice device = handle->device(); if (VariantDeviceIsCustom(device)) { return errors::Unimplemented( "Custom devices and remote execution are currently not supported " diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 7ad05008a3b..10bd8936a74 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -669,8 +669,9 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, AllocatorAttributes cpu_attr; cpu_attr.set_gpu_compatible(true); cpu_attr.set_nic_compatible(true); - auto op_annotation = ScopedMemoryDebugAnnotation( - "GrpcWorker::RecvBufAsync::consumer_callback"); + ScopedMemoryDebugAnnotation op_annotation( + "GrpcWorker::RecvBufAsync::consumer_callback", request->step_id(), + "dynamic", hook->prod_value->dtype(), &hook->prod_value->shape()); Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), hook->prod_value->dtype(), hook->prod_value->shape()); diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 51cc27426b1..8cc8a29fe48 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -27,8 +27,7 @@ limitations under the License. namespace tensorflow { -thread_local const char* pending_op_name = nullptr; -thread_local int64 pending_step_id = 0; +thread_local MemoryDebugAnnotation ScopedMemoryDebugAnnotation::annotation_; string AllocatorStats::DebugString() const { return strings::Printf( diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 46cb8a6cae1..087505f8cd5 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -32,6 +32,8 @@ limitations under the License. 
namespace tensorflow { +class TensorShape; + // Attributes for a single allocation call. Different calls to the same // allocator could potentially have different allocation attributes. struct AllocationAttributes { @@ -62,31 +64,80 @@ struct AllocationAttributes { TF_DISALLOW_COPY_AND_ASSIGN(AllocationAttributes); }; -// The runtime will cache Op names in thread-local memory and some allocators -// will try to tag allocations with the requesting Op. -extern thread_local const char* pending_op_name; -extern thread_local int64 pending_step_id; +// Annotations for memory profiling and debugging purpose. The runtime will +// cache the annotations in thread-local memory, and some allocators will try to +// tag allocations with the annotations. +struct MemoryDebugAnnotation { + const char* pending_op_name = nullptr; + int64 pending_step_id = 0; + const char* pending_region_type = nullptr; + int32 pending_data_type = 0; + const TensorShape* pending_shape = nullptr; +}; -// Wrapper class of pending_op_name and pending_step_id for RAII. +// Wrapper class of MemoryDebugAnnotation for RAII. class ScopedMemoryDebugAnnotation { public: + static const MemoryDebugAnnotation& CurrentAnnotation() { + return annotation_; + } + explicit ScopedMemoryDebugAnnotation(const char* op_name) { - last_op_name_ = pending_op_name; - pending_op_name = op_name; + last_annotation_ = annotation_; + CleanupAnnotation(); + annotation_.pending_op_name = op_name; } explicit ScopedMemoryDebugAnnotation(const char* op_name, int64 step_id) { - last_op_name_ = pending_op_name; - pending_op_name = op_name; - pending_step_id = step_id; + last_annotation_ = annotation_; + CleanupAnnotation(); + annotation_.pending_op_name = op_name; + annotation_.pending_step_id = step_id; } - ~ScopedMemoryDebugAnnotation() { pending_op_name = last_op_name_; } + // This constructor keeps the pending_op_name and pending_step_id from parent + // (if any). Otherwise it overwrites with op_name. 
+ explicit ScopedMemoryDebugAnnotation(const char* op_name, + const char* region_type, int32 data_type, + const TensorShape* shape) { + last_annotation_ = annotation_; + if (!annotation_.pending_op_name) { + annotation_.pending_op_name = op_name; + } + annotation_.pending_region_type = region_type; + annotation_.pending_data_type = data_type; + annotation_.pending_shape = shape; + } + + explicit ScopedMemoryDebugAnnotation(const char* op_name, int64 step_id, + const char* region_type, int32 data_type, + const TensorShape* shape) { + last_annotation_ = annotation_; + annotation_.pending_op_name = op_name; + annotation_.pending_step_id = step_id; + annotation_.pending_region_type = region_type; + annotation_.pending_data_type = data_type; + annotation_.pending_shape = shape; + } + + ~ScopedMemoryDebugAnnotation() { annotation_ = last_annotation_; } private: - // Stores the previous value of pending_op_name in case the annotations are - // nested. - const char* last_op_name_ = nullptr; + void CleanupAnnotation() { + annotation_.pending_op_name = nullptr; + annotation_.pending_step_id = 0; + annotation_.pending_region_type = nullptr; + annotation_.pending_data_type = 0; + annotation_.pending_shape = nullptr; + } + + // Stores the current annotations. + static thread_local MemoryDebugAnnotation annotation_; + + // Stores the previous values in case the annotations are nested. + MemoryDebugAnnotation last_annotation_; + + TF_DISALLOW_COPY_AND_ASSIGN(ScopedMemoryDebugAnnotation); }; // Runtime statistics collected by an allocator. 
Exactly the same as diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 10ea8cb738e..7a49e6e2561 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -337,7 +337,9 @@ class IteratorContext { if (thread_pool) { runner_threadpool_size = thread_pool->NumThreads(); } else { - runner_threadpool_size = port::MaxParallelism(); + static const int32 kDefaultRunnerThreadpoolSize = + port::MaxParallelism(); + runner_threadpool_size = kDefaultRunnerThreadpoolSize; } // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c365716f209..be34afc5105 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -703,8 +703,6 @@ Status OpKernelContext::allocate_tensor( DataType type, const TensorShape& shape, Tensor* out_tensor, AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { Allocator* a = get_allocator(attr); - auto op_annotation = - ScopedMemoryDebugAnnotation(op_kernel().name_view().data(), step_id()); Tensor new_tensor(a, type, shape, AllocationAttributes(allocation_attr.no_retry_on_failure, /* allocation_will_be_logged= */ true, @@ -758,6 +756,8 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, " more than once. Try turning off the ScopedAllocator optimizer."); } } + ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(), + step_id(), "output", type, &shape); auto output_tensor = MakeUnique(); Status s = allocate_tensor(type, shape, output_tensor.get(), attr); if (s.ok()) { @@ -787,6 +787,8 @@ Status OpKernelContext::allocate_temp( << ". 
Switch to allocate_output to avoid performance penalty."; allocator_attr.scope_id = -1; } + ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(), + step_id(), "temp", type, &shape); Status s = allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) { @@ -815,6 +817,8 @@ Status OpKernelContext::allocate_persistent(DataType type, return errors::Internal( "Unexpected call to allocate_persistent with scope_id ", attr.scope_id); } + ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(), + step_id(), "persist", type, &shape); Tensor persistent; Status s = allocate_tensor(type, shape, &persistent, attr); if (s.ok()) { @@ -921,6 +925,9 @@ bool OpKernelContext::maybe_set_output_by_allocate_and_copy( << " params_->forward_from_array[index] " << params_->forward_from_array[index] << " alloc_attr.scope_id " << output_alloc_attr(index).scope_id; + ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(), + step_id(), "output", + tensor.dtype(), &tensor.shape()); auto new_tensor = MakeUnique(); Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(), output_alloc_attr(index)); diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 4c1fa5983f2..f787d879ed6 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -73,7 +73,7 @@ ConstantOp::ConstantOp(OpKernelConstruction* ctx) : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; - auto op_annotation = ScopedMemoryDebugAnnotation(name_view().data()); + ScopedMemoryDebugAnnotation op_annotation(name_view().data()); OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( *proto, AllocatorAttributes(), &tensor_)); diff --git 
a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 7920d84519f..e1677b95959 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -76,6 +76,8 @@ namespace data { ParallelInterleaveDatasetOp::kDeterministic; /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy; +namespace { + constexpr char kTfDataParallelInterleaveWorkerPool[] = "tf_data_parallel_interleave_worker_pool"; constexpr char kParallelism[] = "parallelism"; @@ -113,7 +115,10 @@ constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L; // Period between reporting dataset statistics. constexpr int kStatsReportingPeriodMillis = 1000; -namespace { +inline int64 CeilDiv(int64 numerator, int64 denominator) { + return (numerator + denominator - 1) / denominator; +} + int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements, int64 block_length) { if (configured_buffer_output_elements != model::kAutotune) { @@ -140,6 +145,7 @@ int64 OpVersionFromOpName(absl::string_view op_name) { return 4; } } + } // namespace // The motivation for creating an alternative implementation of parallel @@ -1522,14 +1528,6 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp( void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) { - int64 cycle_length = 0; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length)); - if (cycle_length == model::kAutotune) { - cycle_length = port::NumSchedulableCPUs(); - } - OP_REQUIRES(ctx, cycle_length > 0, - errors::InvalidArgument("`cycle_length` must be > 0")); - int64 block_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length)); OP_REQUIRES(ctx, block_length > 0, @@ -1561,6 +1559,24 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, OP_REQUIRES( 
ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune, errors::InvalidArgument("num_parallel_calls must be greater than zero.")); + int64 cycle_length = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length)); + if (cycle_length == model::kAutotune) { + if (num_parallel_calls != model::kAutotune) { + cycle_length = std::min(num_parallel_calls, + static_cast(port::MaxParallelism())); + } else { + // If parallelism is to be autotuned, we set the cycle length so that + // the number of thread created for the current and future cycle elements + // roughly matches the number of schedulable cores. + const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1; + cycle_length = + CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length); + } + } + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + OP_REQUIRES( ctx, num_parallel_calls <= cycle_length, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc index 8b8819b3cd0..2e1510785b3 100644 --- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc @@ -204,8 +204,9 @@ class DeserializeSparseOp : public OpKernel { target_shape.vec()(i + ndims - 1) = output.shape().data()[i + 1]; } - Reshape(context, output.indices(), input_shape, target_shape, - 0 /* output indices index */, 2 /* output shape index */); + ReshapeSparseTensor(context, output.indices(), input_shape, target_shape, + 0 /* output indices index */, + 2 /* output shape index */); context->set_output(1, output.values()); } diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc index 3b49181f77c..1fce80f7970 100644 --- a/tensorflow/core/kernels/reshape_util.cc +++ b/tensorflow/core/kernels/reshape_util.cc @@ -31,9 +31,11 @@ limitations under the License. 
namespace tensorflow { -void Reshape(OpKernelContext *context, const Tensor &input_indices_in, - const Tensor &input_shape_in, const Tensor &target_shape_in, - int output_indices_idx, int output_shape_idx) { +void ReshapeSparseTensor(OpKernelContext *context, + const Tensor &input_indices_in, + const Tensor &input_shape_in, + const Tensor &target_shape_in, int output_indices_idx, + int output_shape_idx) { OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()), errors::InvalidArgument( "Input indices should be a matrix but received shape ", diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index 90cd30869c8..7e1809e8ca8 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -16,17 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ #define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/status.h" - namespace tensorflow { class OpKernelContext; +class Tensor; // Reshapes the input indices and input shape to the target shape. 
-void Reshape(OpKernelContext *context, const Tensor &input_indices_in, - const Tensor &input_shape_in, const Tensor &target_shape_in, - int output_indices_idx, int output_shape_idx); +void ReshapeSparseTensor(OpKernelContext *context, + const Tensor &input_indices_in, + const Tensor &input_shape_in, + const Tensor &target_shape_in, int output_indices_idx, + int output_shape_idx); } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc index 059519a913b..6eb5f0af635 100644 --- a/tensorflow/core/kernels/sparse_reshape_op.cc +++ b/tensorflow/core/kernels/sparse_reshape_op.cc @@ -34,8 +34,9 @@ class SparseReshapeOp : public OpKernel { explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - Reshape(context, context->input(0), context->input(1), context->input(2), - 0 /* output indices index */, 1 /* output shape index */); + ReshapeSparseTensor(context, context->input(0), context->input(1), + context->input(2), 0 /* output indices index */, + 1 /* output shape index */); } }; diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index dbb6ffdbc6e..391de3a4649 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -589,7 +589,10 @@ def tf_protos_all(): ) def tf_protos_profiler_impl(): - return [clean_dep("//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl")] + return [ + clean_dep("//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl"), + clean_dep("//tensorflow/core/profiler:profiler_options_proto_cc_impl"), + ] def tf_protos_grappler_impl(): return [clean_dep("//tensorflow/core/grappler/costs:op_performance_data_cc_impl")] diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index bfb9a893765..618ac2f6010 100644 --- a/tensorflow/core/profiler/BUILD +++ 
b/tensorflow/core/profiler/BUILD @@ -28,6 +28,20 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "profiler_options_proto", + srcs = ["profiler_options.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +# This is needed because of how tf_android_core_proto_sources parses proto paths. +exports_files( + srcs = ["profiler_options.proto"], + visibility = ["//tensorflow/core:__pkg__"], +) + tf_proto_library( name = "profiler_service_proto", srcs = ["profiler_service.proto"], @@ -35,6 +49,7 @@ tf_proto_library( cc_api_version = 2, cc_grpc_version = 1, protodeps = [ + ":profiler_options_proto", ":profiler_service_monitor_result_proto", ], use_grpc_namespace = True, diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD index 3c9d662d1da..9fab42cd54a 100644 --- a/tensorflow/core/profiler/internal/BUILD +++ b/tensorflow/core/profiler/internal/BUILD @@ -434,6 +434,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index a7005a6738c..8dffc5d42e4 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -50,6 +50,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc index 9dd8256331e..753d8c53b9c 100644 --- 
a/tensorflow/core/profiler/internal/cpu/host_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc @@ -55,8 +55,6 @@ class HostTracer : public ProfilerInterface { Status CollectData(XSpace* space) override; - DeviceType GetDeviceType() override { return DeviceType::kCpu; } - private: // Level of host tracing. const int host_trace_level_; @@ -154,9 +152,9 @@ Status HostTracer::CollectData(XSpace* space) { // Not in anonymous namespace for testing purposes. std::unique_ptr CreateHostTracer( - const profiler::ProfilerOptions& options) { - if (options.host_tracer_level == 0) return nullptr; - return absl::make_unique(options.host_tracer_level); + const ProfileOptions& options) { + if (options.host_tracer_level() == 0) return nullptr; + return absl::make_unique(options.host_tracer_level()); } auto register_host_tracer_factory = [] { diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc index 6c1ef024b33..31cc8509dc6 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" @@ -31,7 +32,7 @@ namespace tensorflow { namespace profiler { std::unique_ptr CreateHostTracer( - const ProfilerOptions& options); + const ProfileOptions& options); namespace { @@ -77,7 +78,7 @@ inline ::testing::PolymorphicMatcher EqualsNodeStats( TEST(HostTracerTest, CollectsTraceMeEventsAsRunMetadata) { uint32 thread_id = Env::Default()->GetCurrentThreadId(); - auto tracer = CreateHostTracer(ProfilerOptions()); + auto tracer = CreateHostTracer(ProfilerSession::DefaultOptions()); TF_ASSERT_OK(tracer->Start()); { TraceMe traceme("hello"); } @@ -122,7 +123,7 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { ASSERT_TRUE(Env::Default()->GetCurrentThreadName(&thread_name)); thread_id = Env::Default()->GetCurrentThreadId(); - auto tracer = CreateHostTracer(ProfilerOptions()); + auto tracer = CreateHostTracer(ProfilerSession::DefaultOptions()); TF_ASSERT_OK(tracer->Start()); { TraceMe traceme("hello"); } diff --git a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc index fbcfaa26e73..c6aa7840920 100644 --- a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc +++ b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc @@ -74,8 +74,6 @@ class MetadataCollector : public ProfilerInterface { return Status::OK(); } - DeviceType GetDeviceType() override { return DeviceType::kCpu; } - private: std::vector debug_info_; bool trace_active_ = false; @@ -84,9 +82,9 @@ class MetadataCollector : public ProfilerInterface { }; std::unique_ptr CreatMetadataCollector( - const profiler::ProfilerOptions& options) { - 
return options.enable_hlo_proto ? absl::make_unique() - : nullptr; + const ProfileOptions& options) { + return options.enable_hlo_proto() ? absl::make_unique() + : nullptr; } } // namespace diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc index e6a910ccc69..aa259f53cfa 100644 --- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc @@ -50,8 +50,6 @@ class PythonTracer : public ProfilerInterface { Status CollectData(XSpace* space) override; - DeviceType GetDeviceType() override { return DeviceType::kCpu; } - private: bool recording_ = false; @@ -107,10 +105,10 @@ Status PythonTracer::CollectData(XSpace* space) { // Not in anonymous namespace for testing purposes. std::unique_ptr CreatePythonTracer( - const profiler::ProfilerOptions& options) { - if (!options.enable_python_tracer) return nullptr; + const ProfileOptions& options) { + if (options.python_tracer_level() == 0) return nullptr; // This ProfilerInterface rely on TraceMeRecorder to be active. 
- if (options.host_tracer_level == 0) return nullptr; + if (options.host_tracer_level() == 0) return nullptr; return absl::make_unique(); } diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index bfe855ef417..24568091e88 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -76,6 +76,7 @@ tf_cc_test_gpu( "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/utils:tf_xplane_visitor", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc index 726cf4600e8..534a1d53752 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc @@ -513,9 +513,6 @@ class GpuTracer : public profiler::ProfilerInterface { Status Stop() override; Status CollectData(RunMetadata* run_metadata) override; Status CollectData(XSpace* space) override; - profiler::DeviceType GetDeviceType() override { - return profiler::DeviceType::kGpu; - } private: Status DoStart(); @@ -679,9 +676,9 @@ Status GpuTracer::CollectData(XSpace* space) { // Not in anonymous namespace for testing purposes. 
std::unique_ptr CreateGpuTracer( - const profiler::ProfilerOptions& options) { - if (options.device_type != profiler::DeviceType::kGpu && - options.device_type != profiler::DeviceType::kUnspecified) + const ProfileOptions& options) { + if (options.device_type() != ProfileOptions::GPU && + options.device_type() != ProfileOptions::UNSPECIFIED) return nullptr; profiler::CuptiTracer* cupti_tracer = profiler::CuptiTracer::GetCuptiTracerSingleton(); diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc index 2e422160a59..e6aacb66b89 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" @@ -45,14 +46,15 @@ namespace tensorflow { namespace profiler { #if GOOGLE_CUDA -std::unique_ptr CreateGpuTracer( - const ProfilerOptions& options); +extern std::unique_ptr CreateGpuTracer( + const ProfileOptions& options); +std::unique_ptr CreateGpuTracer() { + ProfileOptions options = ProfilerSession::DefaultOptions(); + return CreateGpuTracer(options); +} #else // We don't have device tracer for non-cuda case. 
-std::unique_ptr CreateGpuTracer( - const ProfilerOptions& options) { - return nullptr; -} +std::unique_ptr CreateGpuTracer() { return nullptr; } #endif namespace { @@ -111,24 +113,21 @@ class DeviceTracerTest : public ::testing::Test { }; TEST_F(DeviceTracerTest, StartStop) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; TF_EXPECT_OK(tracer->Start()); TF_EXPECT_OK(tracer->Stop()); } TEST_F(DeviceTracerTest, StopBeforeStart) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; TF_EXPECT_OK(tracer->Stop()); TF_EXPECT_OK(tracer->Stop()); } TEST_F(DeviceTracerTest, CollectBeforeStart) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; RunMetadata run_metadata; TF_EXPECT_OK(tracer->CollectData(&run_metadata)); @@ -136,8 +135,7 @@ TEST_F(DeviceTracerTest, CollectBeforeStart) { } TEST_F(DeviceTracerTest, CollectBeforeStop) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; TF_EXPECT_OK(tracer->Start()); RunMetadata run_metadata; @@ -147,9 +145,8 @@ TEST_F(DeviceTracerTest, CollectBeforeStop) { } TEST_F(DeviceTracerTest, StartTwoTracers) { - profiler::ProfilerOptions options; - auto tracer1 = CreateGpuTracer(options); - auto tracer2 = CreateGpuTracer(options); + auto tracer1 = CreateGpuTracer(); + auto tracer2 = CreateGpuTracer(); if (!tracer1 || !tracer2) return; TF_EXPECT_OK(tracer1->Start()); @@ -162,8 +159,7 @@ TEST_F(DeviceTracerTest, StartTwoTracers) { TEST_F(DeviceTracerTest, RunWithTracer) { // On non-GPU platforms, we may not support DeviceTracer. 
- profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; Initialize({3, 2, -1, 0}); @@ -190,8 +186,7 @@ TEST_F(DeviceTracerTest, RunWithTracer) { } TEST_F(DeviceTracerTest, TraceToStepStatsCollector) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; Initialize({3, 2, -1, 0}); @@ -244,8 +239,7 @@ TEST_F(DeviceTracerTest, RunWithTraceOption) { } TEST_F(DeviceTracerTest, TraceToXSpace) { - profiler::ProfilerOptions options; - auto tracer = CreateGpuTracer(options); + auto tracer = CreateGpuTracer(); if (!tracer) return; Initialize({3, 2, -1, 0}); diff --git a/tensorflow/core/profiler/internal/profiler_factory.cc b/tensorflow/core/profiler/internal/profiler_factory.cc index 9fd90c2dc77..e2bae59b892 100644 --- a/tensorflow/core/profiler/internal/profiler_factory.cc +++ b/tensorflow/core/profiler/internal/profiler_factory.cc @@ -36,7 +36,7 @@ void RegisterProfilerFactory(ProfilerFactory factory) { } void CreateProfilers( - const profiler::ProfilerOptions& options, + const ProfileOptions& options, std::vector>* result) { mutex_lock lock(mu); for (auto factory : *GetFactories()) { diff --git a/tensorflow/core/profiler/internal/profiler_factory.h b/tensorflow/core/profiler/internal/profiler_factory.h index 4473e21699e..6bcdcf28c3c 100644 --- a/tensorflow/core/profiler/internal/profiler_factory.h +++ b/tensorflow/core/profiler/internal/profiler_factory.h @@ -24,11 +24,11 @@ namespace tensorflow { namespace profiler { using ProfilerFactory = - std::unique_ptr (*)(const ProfilerOptions&); + std::unique_ptr (*)(const ProfileOptions&); void RegisterProfilerFactory(ProfilerFactory factory); -void CreateProfilers(const ProfilerOptions& options, +void CreateProfilers(const ProfileOptions& options, std::vector>* result); } // namespace profiler diff --git a/tensorflow/core/profiler/internal/profiler_interface.h 
b/tensorflow/core/profiler/internal/profiler_interface.h index c42c278f847..2605e834f09 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/internal/profiler_interface.h @@ -16,50 +16,13 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_ #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace profiler { -enum class DeviceType { - kUnspecified, - kCpu, - kGpu, - kTpu, -}; - -struct ProfilerOptions { - // DeviceType::kUnspecified: All registered device profiler will be enabled. - // DeviceType::kCpu: only CPU will be profiled. - // DeviceType::kGpu: only CPU/GPU will be profiled. - // DeviceType::kTpu: only CPU/TPU will be profiled. - DeviceType device_type = DeviceType::kUnspecified; - - // Levels of host tracing: - // - Level 0 is used to disable host traces. - // - Level 1 enables tracing of only user instrumented (or default) TraceMe. - // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high - // level program execution details (expensive TF ops, XLA ops, etc). - // This is the default. - // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose - // (low-level) program execution details (cheap TF ops, etc). - uint32 host_tracer_level = 2; - - // Levels of device tracing: - // - Level 0 is used to disable device traces. - // - Level 1 is used to enable device traces. - // - More levels might be defined for specific device for controlling the - // verbosity of the trace. - uint32 device_tracer_level = 1; - - // Whether to enable python function calls tracer. - bool enable_python_tracer = false; - - // Whether to capture HLO protos from XLA runtime. - bool enable_hlo_proto = true; -}; - // Interface for tensorflow profiler plugins. 
// // ProfileSession calls each of these methods at most once per instance, and @@ -87,9 +50,6 @@ class ProfilerInterface { // After this or the overload above are called once, subsequent calls might // return empty data. virtual Status CollectData(XSpace* space) = 0; - - // Which device this ProfilerInterface is used for. - virtual DeviceType GetDeviceType() = 0; }; } // namespace profiler diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 4bb1d92c0cb..18ffc3f1e5c 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -26,6 +26,7 @@ cc_library( deps = [ "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ] + if_static([ @@ -50,6 +51,7 @@ cc_library( "//tensorflow/core/platform", "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core/util:ptr_util", ] + if_not_android([ ":profiler_utils", diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index 2cdf5570034..b907f74179c 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -33,8 +33,17 @@ limitations under the License. 
namespace tensorflow { +namespace { +ProfileOptions GetOptions(const ProfileOptions& opts) { + if (opts.version()) return opts; + ProfileOptions options = ProfilerSession::DefaultOptions(); + options.set_include_dataset_ops(opts.include_dataset_ops()); + return options; +} +}; // namespace + /*static*/ std::unique_ptr ProfilerSession::Create( - const profiler::ProfilerOptions& options) { + const ProfileOptions& options) { return WrapUnique(new ProfilerSession(options)); } @@ -45,12 +54,12 @@ namespace tensorflow { if (!s.ok()) { LOG(WARNING) << "ProfilerSession: " << s.error_message(); } - profiler::ProfilerOptions options; - options.host_tracer_level = host_tracer_level; + ProfileOptions options = DefaultOptions(); + options.set_host_tracer_level(host_tracer_level); return Create(options); } -Status ProfilerSession::Status() { +tensorflow::Status ProfilerSession::Status() { mutex_lock l(mutex_); return status_; } @@ -122,14 +131,14 @@ Status ProfilerSession::CollectData(RunMetadata* run_metadata) { return Status::OK(); } - -ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options) +ProfilerSession::ProfilerSession(const ProfileOptions& options) #if !defined(IS_MOBILE_PLATFORM) : active_(profiler::AcquireProfilerLock()), #else : active_(false), #endif - start_time_ns_(EnvTime::NowNanos()) { + start_time_ns_(EnvTime::NowNanos()), + options_(GetOptions(options)) { if (!active_) { #if !defined(IS_MOBILE_PLATFORM) status_ = tensorflow::Status(error::UNAVAILABLE, @@ -145,7 +154,7 @@ ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options) LOG(INFO) << "Profiler session started."; #if !defined(IS_MOBILE_PLATFORM) - CreateProfilers(options, &profilers_); + CreateProfilers(options_, &profilers_); #endif status_ = Status::OK(); diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h index ba977d72567..1c20876d9d0 100644 --- a/tensorflow/core/profiler/lib/profiler_session.h +++ 
b/tensorflow/core/profiler/lib/profiler_session.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { @@ -36,10 +37,21 @@ namespace tensorflow { class ProfilerSession { public: // Creates and ProfilerSession and starts profiling. - static std::unique_ptr Create( - const profiler::ProfilerOptions& options); + static std::unique_ptr Create(const ProfileOptions& options); static std::unique_ptr Create(); + static ProfileOptions DefaultOptions() { + ProfileOptions options; + options.set_version(1); + options.set_device_tracer_level(1); + options.set_host_tracer_level(2); + options.set_device_type(ProfileOptions::UNSPECIFIED); + options.set_python_tracer_level(0); + options.set_enable_hlo_proto(false); + options.set_include_dataset_ops(true); + return options; + } + // Deletes an existing Profiler and enables starting a new one. ~ProfilerSession(); @@ -53,7 +65,7 @@ class ProfilerSession { private: // Constructs an instance of the class and starts profiling - explicit ProfilerSession(const profiler::ProfilerOptions& options); + explicit ProfilerSession(const ProfileOptions& options); // ProfilerSession is neither copyable or movable. ProfilerSession(const ProfilerSession&) = delete; @@ -68,6 +80,7 @@ class ProfilerSession { tensorflow::Status status_ TF_GUARDED_BY(mutex_); const uint64 start_time_ns_; mutex mutex_; + ProfileOptions options_; }; } // namespace tensorflow diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index a0ea97bd8e0..8b42f187850 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -15,7 +15,9 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ #define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ +#include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -160,6 +162,26 @@ class TraceMe { #endif } + // Sets new_metadata in the metadata part of no_init_.name. + void SetMetadata(absl::string_view new_metadata) { +#if !defined(IS_MOBILE_PLATFORM) + if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) { + if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) { + absl::string_view orig = no_init_.name; + if (absl::EndsWith(orig, "#")) { + // orig does have metadata. + absl::ConsumeSuffix(&orig, "#"); + absl::ConsumePrefix(&new_metadata, "#"); + no_init_.name = absl::StrCat(orig, ",", new_metadata); + } else { + // orig does not have metadata. + absl::StrAppend(&no_init_.name, new_metadata); + } + } + } +#endif + } + ~TraceMe() { Stop(); } // Static API, for use when scoped objects are inconvenient. diff --git a/tensorflow/core/profiler/profiler_options.proto b/tensorflow/core/profiler/profiler_options.proto new file mode 100644 index 00000000000..8b4fc3de6fc --- /dev/null +++ b/tensorflow/core/profiler/profiler_options.proto @@ -0,0 +1,54 @@ +syntax = "proto3"; + +package tensorflow; + +message ProfileOptions { + // Some default value of option are not proto3 default value. Use this version + // to determine if we should use default option value instead of proto3 + // default value. + uint32 version = 5; + + enum DeviceType { + UNSPECIFIED = 0; + CPU = 1; + GPU = 2; + TPU = 3; + } + + // Device type to profile/trace: (version >= 1) + // DeviceType::UNSPECIFIED: All registered device profiler will be enabled. + // DeviceType::CPU: only CPU will be profiled. + // DeviceType::GPU: only CPU/GPU will be profiled. + // DeviceType::TPU: only CPU/TPU will be profiled. 
+ DeviceType device_type = 6; + + // We don't collect the dataset ops by default for better trace-viewer + scalability. The caller can manually set this field to include the ops. + bool include_dataset_ops = 1; + + // Levels of host tracing: (version >= 1) + // - Level 0 is used to disable host traces. + // - Level 1 enables tracing of only user instrumented (or default) TraceMe. + // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high + // level program execution details (expensive TF ops, XLA ops, etc). + // This is the default. + // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose + // (low-level) program execution details (cheap TF ops, etc). + uint32 host_tracer_level = 2; + + // Levels of device tracing: (version >= 1) + // - Level 0 is used to disable device traces. + // - Level 1 is used to enable device traces. + // - More levels might be defined for specific device for controlling the + // verbosity of the trace. + uint32 device_tracer_level = 3; + + // Whether enable python function calls tracing. Runtime overhead ensues if + // enabled. Default off. (version >= 1) + uint32 python_tracer_level = 4; + + // Whether serialize hlo_proto when XLA is used. 
(version >= 1) + bool enable_hlo_proto = 7; + + // next-field: 8 +} diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto index 007b68e9482..37ca4084e42 100644 --- a/tensorflow/core/profiler/profiler_service.proto +++ b/tensorflow/core/profiler/profiler_service.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; +import "tensorflow/core/profiler/profiler_options.proto"; import "tensorflow/core/profiler/profiler_service_monitor_result.proto"; // The ProfilerService service retrieves performance information about @@ -13,40 +14,6 @@ service ProfilerService { rpc Monitor(MonitorRequest) returns (MonitorResponse) {} } -message ProfileOptions { - // Some default value of option are not proto3 default value. Use this version - // to determine if we should use default option value instead of proto3 - // default value. - uint32 version = 5; - - // We don't collect the dataset ops by default for better trace-viewer - // scalability. The caller can mannually set this field to include the ops. - bool include_dataset_ops = 1; - - // Levels of host tracing: (version >= 1) - // - Level 0 is used to disable host traces. - // - Level 1 enables tracing of only user instrumented (or default) TraceMe. - // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high - // level program execution details (expensive TF ops, XLA ops, etc). - // This is the default. - // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose - // (low-level) program execution details (cheap TF ops, etc). - uint32 host_tracer_level = 2; - - // Levels of device tracing: (version >= 1) - // - Level 0 is used to disable device traces. - // - Level 1 is used to enable device traces. - // - More levels might be defined for specific device for controlling the - // verbosity of the trace. - uint32 device_tracer_level = 3; - - // Whether enable python function calls tracing. Runtime overhead ensues if - // enabled. 
Default off. (version >= 1) - uint32 python_tracer_level = 4; - - // next-field: 6 -} - message ToolRequestOptions { // Required formats for the tool, it should be one of "json", "proto", "raw" // etc. If not specified (backward compatible), use default format, i.e. most diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index 407cd0ae0a6..8f1be23594a 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -53,7 +53,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { ProfileResponse* response) override { VLOG(1) << "Received a profile request: " << req->DebugString(); std::unique_ptr profiler = - ProfilerSession::Create(GetOptions(req->opts())); + ProfilerSession::Create(req->opts()); Status status = profiler->Status(); if (!status.ok()) { return ::grpc::Status(::grpc::StatusCode::INTERNAL, @@ -76,19 +76,6 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { return ::grpc::Status::OK; } - - private: - profiler::ProfilerOptions GetOptions(const tensorflow::ProfileOptions& opts) { - profiler::ProfilerOptions options; - if (opts.version()) { - options.host_tracer_level = opts.host_tracer_level(); - options.device_tracer_level = opts.device_tracer_level(); - options.enable_python_tracer = opts.python_tracer_level() > 0; - } else { - // use default options value; - } - return options; - } }; } // namespace diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index da0ba034dbe..47e4a3e1e81 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -126,6 +126,8 @@ const StatTypeMap& GetStatTypeMap() { {"requested_bytes", kRequestedBytes}, {"allocation_bytes", kAllocationBytes}, {"addr", kAddress}, + {"region_type", kRegionType}, + {"data_type", kDataType}, {"shape", 
kTensorShapes}, // Device trace arguments. {"device_id", kDeviceId}, diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index bce6c5ecc8f..e85808253de 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -119,6 +119,8 @@ enum StatType { kRequestedBytes, kAllocationBytes, kAddress, + kRegionType, + kDataType, kTensorShapes, // Device trace arguments. kDeviceId, diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index fa9e62186fa..1560a35fe17 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -83,6 +83,7 @@ FRAMEWORK_LIB_HDRS = [ cc_library( name = "version", hdrs = ["version.h"], + build_for_embedded = True, copts = TFLITE_DEFAULT_COPTS, # Note that we only use the header defines from :version_lib. deps = ["//tensorflow/core:version_lib"], @@ -135,6 +136,7 @@ cc_library( name = "external_cpu_backend_context", srcs = ["external_cpu_backend_context.cc"], hdrs = ["external_cpu_backend_context.h"], + build_for_embedded = True, copts = TFLITE_DEFAULT_COPTS, deps = [ "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 4dfea10cf2d..f3c1be6fd11 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -762,6 +762,36 @@ TfLiteStatus Subgraph::ResizeInputTensor(int tensor_index, return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims)); } +TfLiteStatus Subgraph::ResizeInputTensorStrict(int tensor_index, + const std::vector& dims) { + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + TfLiteTensor* tensor = &context_.tensors[tensor_index]; + + // Ensure that only unknown dimensions can be resized. + TF_LITE_ENSURE_EQ(&context_, tensor->dims->size, dims.size()); + for (size_t idx = 0; idx < dims.size(); idx++) { + // `dims_signature` is not defined when no unknown dimensions are present. 
+ int dim_signature; + if (tensor->dims_signature && tensor->dims_signature->size) { + dim_signature = tensor->dims_signature->data[idx]; + } else { + dim_signature = tensor->dims->data[idx]; + } + + if (dim_signature != -1 && dim_signature != dims[idx]) { + ReportError( + "Attempting to resize dimension %d of tensor %d with value %d to %d. " + "ResizeInputTensorStrict only allows mutating unknown dimensions " + "identified by -1.", + idx, tensor_index, dim_signature, dims[idx]); + return kTfLiteError; + } + } + + return ResizeInputTensor(tensor_index, dims); +} + TfLiteStatus Subgraph::ReleaseNonPersistentMemory() { if (memory_planner_) { TF_LITE_ENSURE_STATUS(memory_planner_->ReleaseNonPersistentMemory()); @@ -802,7 +832,7 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; EnsureTensorsVectorCapacity(); - if (OpPrepare(registration, &node) == kTfLiteError) { + if (OpPrepare(registration, &node) != kTfLiteOk) { return ReportOpError(&context_, node, registration, node_index, "failed to prepare"); } @@ -909,7 +939,7 @@ TfLiteStatus Subgraph::Invoke() { EnsureTensorsVectorCapacity(); tensor_resized_since_op_invoke_ = false; - if (OpInvoke(registration, &node) == kTfLiteError) { + if (OpInvoke(registration, &node) != kTfLiteOk) { return ReportOpError(&context_, node, registration, node_index, "failed to invoke"); } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 70ca4d24b61..a85eeab5696 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -221,6 +221,15 @@ class Subgraph { TfLiteStatus ResizeInputTensor(int tensor_index, const std::vector& dims); + // WARNING: Experimental interface, subject to change + // Change the dimensionality of a given tensor. This is only acceptable for + // tensor indices that are inputs or variables. Only unknown dimensions can be + // resized with this function. 
Unknown dimensions are indicated as `-1` in the + // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure + // or success. + TfLiteStatus ResizeInputTensorStrict(int tensor_index, + const std::vector& dims); + // This releases memory held by non-persistent tensors. It does NOT re-perform // memory planning. // AllocateTensors needs to be called before next invocation. diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index d6875476dec..099f653a1b8 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -237,6 +237,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/gl:api2", + "//tensorflow/lite/kernels/internal:optimized_base", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 3d3f685f66a..d80a2fb0a4a 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -91,6 +91,7 @@ cc_library( ":status", ":tensor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", "@com_google_absl//absl/types:optional", ], @@ -139,6 +140,8 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/schema:schema_fbs", "@FP16", "@com_google_absl//absl/memory", diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h index 1a68b6975dd..7b499b3ef2b 100644 --- a/tensorflow/lite/delegates/gpu/common/model.h +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -24,6 +24,7 @@ limitations under 
the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/types/any.h" #include "absl/types/optional.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" @@ -305,8 +306,8 @@ class Model : public Graph { // check if this value has the same producer already if (node_ptr == v->producer) { - return absl::InvalidArgumentError( - "Node is already a producer of the value"); + return absl::AlreadyExistsError(absl::StrCat( + "Node ", producer, " is already a producer of the value ", value)); } // Check if the node is a consumer of this value. @@ -389,8 +390,8 @@ class Model : public Graph { // check if this value has the same consumer already if (IsInput(consumer, value)) { - return absl::InvalidArgumentError( - "Node is already a consumer of the value"); + return absl::AlreadyExistsError(absl::StrCat( + "Node ", consumer, " is already a consumer of the value ", value)); } n->inputs.push_back(value_ptr); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index e2cc431e79b..513920a7dd1 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -48,6 +49,8 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/kernels/internal/reference/dequantize.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/util.h" @@ -96,6 +99,30 @@ void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, } } +template +inline void DequantizeConstantTensor(const TfLiteTensor& tensor, + const T* source_data, + float* dequantized_data) { + TfLiteAffineQuantization* quant_params = + reinterpret_cast(tensor.quantization.params); + if (quant_params->scale->size > 1) { + // Tensor is per-channel quantized. + PerChannelDequantizationParams op_params; + op_params.zero_point = quant_params->zero_point->data; + op_params.scale = quant_params->scale->data; + op_params.quantized_dimension = quant_params->quantized_dimension; + reference_ops::PerChannelDequantize(op_params, GetTensorShape(&tensor), + source_data, GetTensorShape(&tensor), + dequantized_data); + } else { + DequantizationParams op_params; + op_params.zero_point = tensor.params.zero_point; + op_params.scale = tensor.params.scale; + reference_ops::Dequantize(op_params, GetTensorShape(&tensor), source_data, + GetTensorShape(&tensor), dequantized_data); + } +} + template <> absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, float* tensor_data) { @@ -108,6 +135,15 @@ absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, NumElements(&tensor), reinterpret_cast(tensor.data.f16), tensor_data); break; + case kTfLiteInt8: + DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data); + break; + case kTfLiteUInt8: + DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data); + break; + case kTfLiteInt32: + DequantizeConstantTensor(tensor, 
tensor.data.i32, tensor_data); + break; default: return absl::InvalidArgumentError( "Unsupported data type for float32 tensor"); @@ -210,6 +246,8 @@ DataType ToDataType(TfLiteType type) { return DataType::INT32; case kTfLiteInt64: return DataType::INT64; + case kTfLiteInt8: + return DataType::INT8; case kTfLiteUInt8: return DataType::UINT8; default: @@ -292,17 +330,62 @@ absl::Status CheckInputsConstsOutputs(const TfLiteContext* context, return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs); } +// Populates quantization parameters for non-constant UInt8/Int8 tensors. +// This helps the delegate emulate quantized inference with +// QuantizeAndDequantize. +absl::Status PopulateQuantParams(const TfLiteTensor& tensor, + QuantizationParams* quant_params) { + const TfLiteQuantization& quant = tensor.quantization; + if (quant.type != TfLiteQuantizationType::kTfLiteAffineQuantization) { + return absl::InvalidArgumentError( + absl::StrCat("Tensor not quantized: ", std::string(tensor.name))); + } + const TfLiteAffineQuantization* params = + static_cast(quant.params); + if (params->scale->size > 1) { + return absl::InvalidArgumentError( + absl::StrCat("Non-constant per-channel quantized tensor: ", + std::string(tensor.name))); + } + const float scale = params->scale->data[0]; + const float zero_point = static_cast(params->zero_point->data[0]); + + float qmin_value = 0; + float qmax_value = 0; + if (tensor.type == kTfLiteUInt8) { + qmin_value = static_cast(std::numeric_limits::min()); + qmax_value = static_cast(std::numeric_limits::max()); + } else if (tensor.type == kTfLiteInt8) { + qmin_value = static_cast(std::numeric_limits::min()); + qmax_value = static_cast(std::numeric_limits::max()); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Type invalid for quantized tensor: ", std::string(tensor.name))); + } + quant_params->min = scale * (static_cast(qmin_value) - zero_point); + quant_params->max = scale * (static_cast(qmax_value) - 
zero_point); + quant_params->scale = scale; + + return absl::OkStatus(); +} + +// If quantized tensors exist in the graph & quant_conversion_map is non-null, +// the mapping between the original tensors (fixed-point) & GPU values (fp) is +// stored in quant_conversion_map. class ObjectReader { public: - ObjectReader(GraphFloat32* graph, TfLiteContext* context, - const TfLiteNode* tflite_node, - std::vector>*>* tensor_to_value) + ObjectReader( + GraphFloat32* graph, TfLiteContext* context, + const TfLiteNode* tflite_node, + std::unordered_map>*>* tensor_to_value, + std::unordered_map* quant_conversion_map = nullptr) : graph_(graph), context_(context), tflite_node_(tflite_node), - tensor_to_value_(tensor_to_value) {} + tensor_to_value_(tensor_to_value), + quant_conversion_map_(quant_conversion_map) {} - absl::Status ReadValue(uint32_t idx, Value>** value) const { + absl::Status ReadValue(uint32_t idx, Value>** value) { if (idx >= tflite_node_->inputs->size) { return absl::OutOfRangeError( absl::StrCat("ReadValue: input tensor index: ", idx)); @@ -368,23 +451,60 @@ class ObjectReader { } absl::Status ReadValueByTensorIdx(uint32_t tensor_idx, - Value>** value) const { - if (tensor_idx >= tensor_to_value_->size()) { + Value>** value) { + if (tensor_idx >= context_->tensors_size) { return absl::OutOfRangeError( absl::StrCat("ReadValue: input tensor index: ", tensor_idx)); } - if ((*tensor_to_value_)[tensor_idx] == nullptr) { + + if (tensor_to_value_->find(tensor_idx) == tensor_to_value_->end()) { const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; if (tflite::IsConstantTensor(&tflite_tensor)) { return absl::NotFoundError(absl::StrCat( "ReadValue: value is a constant tensor: ", tensor_idx)); } - Value>* value = graph_->NewValue(); - RETURN_IF_ERROR( - ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor)); - value->tensor.ref = tensor_idx; - (*tensor_to_value_)[tensor_idx] = value; + + if ((tflite_tensor.type == kTfLiteInt8 || + tflite_tensor.type 
== kTfLiteUInt8) && + quant_conversion_map_) { + // Quantized case + if (quant_conversion_map_->find(tensor_idx) == + quant_conversion_map_->end()) { + // Since the original tensor is fixed-point, add a new float tensor to + // the TFLite graph to represent the dequantized data. + int fp_tensor_index = 0; + TfLiteTensor* fp_tflite_tensor; + if (delegates::CreateNewTensorWithDifferentType( + context_, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor, + &fp_tensor_index) != kTfLiteOk) { + return absl::InternalError("Could not add new tensor to graph"); + } + // Remember this tensor for later. + (*quant_conversion_map_)[fp_tensor_index] = tensor_idx; + (*quant_conversion_map_)[tensor_idx] = fp_tensor_index; + // Add a new GPU Value for the new dequantized floating-point tensor. + Value>* value = graph_->NewValue(); + RETURN_IF_ERROR(ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, + &value->tensor)); + value->tensor.ref = fp_tensor_index; + value->quant_params.emplace(); + RETURN_IF_ERROR( + PopulateQuantParams(tflite_tensor, &value->quant_params.value())); + (*tensor_to_value_)[fp_tensor_index] = value; + } + // We do not use the original tensor index as reference for the GPU + // Value, instead pointing at the corresponding float version. + tensor_idx = quant_conversion_map_->at(tensor_idx); + } else { + // Floating-point case. 
+ Value>* value = graph_->NewValue(); + RETURN_IF_ERROR( + ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor)); + value->tensor.ref = tensor_idx; + (*tensor_to_value_)[tensor_idx] = value; + } } + *value = (*tensor_to_value_)[tensor_idx]; return absl::OkStatus(); } @@ -410,9 +530,10 @@ class ObjectReader { private: GraphFloat32* graph_ = nullptr; - const TfLiteContext* context_ = nullptr; + TfLiteContext* context_ = nullptr; const TfLiteNode* tflite_node_ = nullptr; - std::vector>*>* tensor_to_value_; + std::unordered_map>*>* tensor_to_value_; + std::unordered_map* quant_conversion_map_; }; // A parser responsible for parsing TFLite operation and adding it to a graph. @@ -1077,6 +1198,43 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { } }; +class DequantizeOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); + return absl::OkStatus(); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing + // with floating-point versions of the original tensors. + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + // Quantization attributes should already be present in the input tensor. 
+ auto input_value = graph->FindInputs(node->id)[0]; + if (!input_value->quant_params) { + return absl::InvalidArgumentError( + "Encountered Dequantize input with no quant params"); + } + QuantizeAndDequantizeAttributes attr; + attr.min = input_value->quant_params.value().min; + attr.max = input_value->quant_params.value().max; + attr.scale = input_value->quant_params.value().scale; + + node->operation.attributes = attr; + return absl::OkStatus(); + } +}; + class ElementwiseOperationParser : public TFLiteOperationParser { public: explicit ElementwiseOperationParser(OperationType operation_type) @@ -1736,6 +1894,43 @@ class Pooling2DOperationParser : public TFLiteOperationParser { const PoolingType type_; }; +class QuantizeOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, /*outputs=*/1)); + return absl::OkStatus(); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing + // with floating-point versions of the original tensors. + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + // Quantization attributes should already be present in the output tensor. 
+ auto output_value = graph->FindOutputs(node->id)[0]; + if (!output_value->quant_params) { + return absl::InvalidArgumentError( + "Encountered Quantize output with no quant params"); + } + QuantizeAndDequantizeAttributes attr; + attr.min = output_value->quant_params.value().min; + attr.max = output_value->quant_params.value().max; + attr.scale = output_value->quant_params.value().scale; + + node->operation.attributes = attr; + return absl::OkStatus(); + } +}; + class ReLUOperationParser : public TFLiteOperationParser { public: explicit ReLUOperationParser(int clip) : clip_(clip) {} @@ -2696,7 +2891,7 @@ class UnsupportedOperationParser : public TFLiteOperationParser { }; std::unique_ptr NewOperationParser( - const TfLiteRegistration* registration) { + const TfLiteRegistration* registration, bool allow_quant_ops = false) { const auto builtin_code = registration->builtin_code; switch (builtin_code) { case kTfLiteBuiltinAbs: @@ -2713,6 +2908,11 @@ std::unique_ptr NewOperationParser( return absl::make_unique(OperationType::COS); case kTfLiteBuiltinDepthwiseConv2d: return absl::make_unique(); + case kTfLiteBuiltinDequantize: + if (allow_quant_ops) { + return absl::make_unique(); + } + break; case kTfLiteBuiltinDiv: return absl::make_unique(OperationType::DIV); case kTfLiteBuiltinFullyConnected: @@ -2744,6 +2944,11 @@ std::unique_ptr NewOperationParser( return absl::make_unique(/*mirror_pad=*/false); case kTfLiteBuiltinPow: return absl::make_unique(OperationType::POW); + case kTfLiteBuiltinQuantize: + if (allow_quant_ops) { + return absl::make_unique(); + } + break; case kTfLiteBuiltinRelu: return absl::make_unique(0); case kTfLiteBuiltinRelu6: @@ -2826,17 +3031,23 @@ std::unique_ptr NewOperationParser( } absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, - const TfLiteRegistration* registration) { - return NewOperationParser(registration) + const TfLiteRegistration* registration, + bool allow_quant_ops = false) { + return 
NewOperationParser(registration, allow_quant_ops) ->IsSupported(context, node, registration); } -bool IsAllFloatTensors(const TfLiteContext* context, - const TfLiteIntArray* array) { +bool IsAllAllowedTensors(TfLiteContext* context, const TfLiteIntArray* array, + bool allow_quant_ops = false) { for (int i = 0; i < array->size; ++i) { const TfLiteTensor* t = context->tensors + array->data[i]; - bool const type_supported = + bool type_supported = (t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16); + if (allow_quant_ops) { + // Since we only check non-constant tensors, type cannot be Int32. + type_supported = + type_supported || t->type == kTfLiteInt8 || t->type == kTfLiteUInt8; + } if (t->allocation_type == kTfLiteArenaRw && !type_supported) { return false; } @@ -2853,12 +3064,13 @@ absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, // TODO(impjdi): Check number of input/output tensors and their dimensions. // TODO(impjdi): Check ops' parameters. -TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { +TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) { IsNodeSupportedFn node_supported_fn = [=](TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, std::string* unsupported_details) -> bool { - const auto status = IsSupported(context, node, registration); + const auto status = + IsSupported(context, node, registration, allow_quant_ops); if (!status.ok()) { if (unsupported_details) { *unsupported_details = std::string(status.message()); @@ -2866,8 +3078,8 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { return false; } - if (!IsAllFloatTensors(context, node->inputs) || - !IsAllFloatTensors(context, node->outputs)) { + if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) || + !IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) { if (unsupported_details) { *unsupported_details = "OP is supported, but tensor type isn't matched!"; @@ -2914,7 +3126,8 @@ 
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { absl::Status BuildModel(TfLiteContext* context, const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { + GraphFloat32* graph, + std::unordered_map* quant_conversion_map) { std::vector> operations; std::vector tflite_nodes; for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { @@ -2923,11 +3136,14 @@ absl::Status BuildModel(TfLiteContext* context, RETURN_IF_ERROR(GetNodeAndRegistration( context, delegate_params->nodes_to_replace->data[i], &tflite_node, ®istration)); - if (registration->builtin_code == kTfLiteBuiltinDequantize) { - // Ignore Dequantize nodes. + if (registration->builtin_code == kTfLiteBuiltinDequantize && + context->tensors[tflite_node->inputs->data[0]].type == + TfLiteType::kTfLiteFloat16) { + // Ignore Fp16 Dequantize nodes. continue; } - auto op_parser = NewOperationParser(registration); + auto op_parser = NewOperationParser( + registration, /*allow_quant_ops=*/quant_conversion_map != nullptr); if (!op_parser) { return absl::UnimplementedError( absl::StrCat("Operation ", registration->builtin_code, "(", @@ -2937,15 +3153,15 @@ absl::Status BuildModel(TfLiteContext* context, operations.push_back(std::move(op_parser)); tflite_nodes.push_back(i); } - std::vector>*> tensor_to_value(context->tensors_size, - nullptr); + std::unordered_map>*> tensor_to_value; for (int i = 0; i < operations.size(); ++i) { TfLiteNode* tflite_node; TfLiteRegistration* registration; RETURN_IF_ERROR(GetNodeAndRegistration( context, delegate_params->nodes_to_replace->data[tflite_nodes[i]], &tflite_node, ®istration)); - ObjectReader reader(graph, context, tflite_node, &tensor_to_value); + ObjectReader reader(graph, context, tflite_node, &tensor_to_value, + quant_conversion_map); const auto status = operations[i]->Parse(tflite_node, registration, graph, &reader); if (!status.ok()) { @@ -2956,10 +3172,11 @@ absl::Status BuildModel(TfLiteContext* context, return absl::OkStatus(); } -absl::Status 
BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph) { - RETURN_IF_ERROR(BuildModel(context, delegate_params, graph)); +absl::Status BuildFinalModel( + TfLiteContext* context, const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph, std::unordered_map* quant_conversion_map) { + RETURN_IF_ERROR( + BuildModel(context, delegate_params, graph, quant_conversion_map)); // Apply general transformations on the graph. NullTransformationReporter reporter; diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index b8fcab0c5c8..4b2a2f51db3 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "tensorflow/lite/context.h" #include "tensorflow/lite/delegates/gpu/common/model.h" @@ -28,19 +29,36 @@ namespace gpu { // Validates which operations are supported and returns array of operations to // replace with GPU kernels. The caller must free the pointer on TfLiteIntArray. -TfLiteIntArray* GetOpsToReplace(TfLiteContext* context); +TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, + bool allow_quant_ops = false); // Extracts TFLite delegate execution plan from the input TFLite context and // converts it into generic graph format. -absl::Status BuildModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +// +// If model is quantized, quant_conversion_map maps the dequantized tensor +// (floating-point) to the original tensor (fixed-point) & vice-versa. +// NOTE: Not all of these new tensors will have any data and need memory +// allocated for them. We need to do that only for the overall GPU graph inputs +// & outputs. This should be done by the delegate, by setting the appropriate +// TfLiteNode->temporaries. 
+absl::Status BuildModel( + TfLiteContext* context, const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph, + std::unordered_map* quant_conversion_map = nullptr); // Same as above but also apply all transformations on the final graph. // Prefer using this method instead of BuildModel. -absl::Status BuildFinalModel(TfLiteContext* context, - const TfLiteDelegateParams* delegate_params, - GraphFloat32* graph); +// +// If model is quantized, quant_conversion_map maps the dequantized tensor +// (floating-point) to the original TFLite tensor (fixed-point) & vice-versa. +// NOTE: Not all of these new tensors will have any data and need memory +// allocated for them. We need to do that only for the overall GPU graph inputs +// & outputs. This should be done by the delegate, by setting the appropriate +// TfLiteNode->temporaries. +absl::Status BuildFinalModel( + TfLiteContext* context, const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph, + std::unordered_map* quant_conversion_map = nullptr); // Module-internal converter, exposed for unit testing purpose only. 
absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h index cf1367079c7..4dc76eb22ae 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -158,7 +158,9 @@ class GraphWithDequantPartitionHelper TfLiteRegistration* registration = nullptr; GetNodeAndRegistration(context_, node_id, &node, ®istration) .IgnoreError(); - if (registration->builtin_code != kTfLiteBuiltinDequantize) { + if (registration->builtin_code != kTfLiteBuiltinDequantize || + context_->tensors[node->inputs->data[0]].type != + TfLiteType::kTfLiteFloat16) { ++it; continue; } diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index 214d02599d5..7b12f46453d 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -828,6 +828,187 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { TfLiteIntArrayFree(ops_to_replace); } +// Adds the pattern: +// +// float -> QUANTIZE -> ADD -> DEQUANTIZE -> float +// float -> QUANTIZE ----^ +// +// The tensors between the QUANTIZE & DEQUANTIZE nodes are int8. 
+class InterpreterQuantized : public DelegatedInterpreter { + public: + InterpreterQuantized() : DelegatedInterpreter(4) { + void* builtin_data = malloc(sizeof(int)); + EXPECT_EQ(interpreter_.AddTensors(6), kTfLiteOk); + EXPECT_EQ(interpreter_.SetInputs({0, 3}), kTfLiteOk); + EXPECT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk); + + // QUANTIZE 1 + const TfLiteRegistration reg_quant0 = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/nullptr, + /*profiling_string=*/nullptr, + kTfLiteBuiltinQuantize}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr, + /*init_data_size=*/0, /*builtin_data=*/nullptr, + /*registration=*/®_quant0), + kTfLiteOk); + + // QUANTIZE 2 + const TfLiteRegistration reg_quant1 = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/nullptr, + /*profiling_string=*/nullptr, + kTfLiteBuiltinQuantize}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{3}, /*outputs=*/{2}, /*init_data=*/nullptr, + /*init_data_size=*/0, /*builtin_data=*/nullptr, + /*registration=*/®_quant1), + kTfLiteOk); + + // ADD + const TfLiteRegistration reg_add0 = { + [](TfLiteContext* context, const char* buffer, size_t length) { + return reinterpret_cast(new int(1)); + }, + [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }, + nullptr, + nullptr, + nullptr, + kTfLiteBuiltinAdd}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{1, 2}, /*outputs=*/{4}, /*init_data=*/nullptr, + /*init_data_size=*/0, + /*builtin_data=*/builtin_data, + /*registration=*/®_add0), + kTfLiteOk); + + // DEQUANTIZE + const TfLiteRegistration reg_dequant0 = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/nullptr, + /*profiling_string=*/nullptr, + kTfLiteBuiltinDequantize}; + EXPECT_EQ(interpreter_.AddNodeWithParameters( + /*inputs=*/{4}, /*outputs=*/{5}, /*init_data=*/nullptr, + /*init_data_size=*/0, 
/*builtin_data=*/nullptr, + /*registration=*/®_dequant0), + kTfLiteOk); + + const std::vector dims = {1, 3, 3, 2}; + + // Input & output tensors are floating-point. + TfLiteQuantization no_quantization; + no_quantization.type = kTfLiteNoQuantization; + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 0, TfLiteType::kTfLiteFloat32, "t0", dims, no_quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 3, TfLiteType::kTfLiteFloat32, "t3", dims, no_quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 5, TfLiteType::kTfLiteFloat32, "t5", dims, no_quantization, false), + kTfLiteOk); + // Other tensors are int8. + float scale = 0.5f; + int32_t zero_point = 12; + TfLiteQuantization rw_quantization; + rw_quantization.type = kTfLiteAffineQuantization; + auto* rw_affine_quantization = static_cast( + malloc(sizeof(TfLiteAffineQuantization))); + rw_affine_quantization->scale = TfLiteFloatArrayCreate(1); + rw_affine_quantization->zero_point = TfLiteIntArrayCreate(1); + rw_affine_quantization->scale->data[0] = scale; + rw_affine_quantization->zero_point->data[0] = zero_point; + rw_quantization.params = rw_affine_quantization; + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 1, TfLiteType::kTfLiteInt8, "t1", dims, rw_quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 2, TfLiteType::kTfLiteInt8, "t2", dims, rw_quantization, false), + kTfLiteOk); + EXPECT_EQ( + interpreter_.SetTensorParametersReadWrite( + 4, TfLiteType::kTfLiteInt8, "t4", dims, rw_quantization, false), + kTfLiteOk); + + exec_plan()->data[0] = 0; + exec_plan()->data[1] = 1; + exec_plan()->data[2] = 2; + exec_plan()->data[3] = 3; + } +}; + +InterpreterQuantized* interpreter_quant = new InterpreterQuantized(); +TEST(ModelBuilderTest, GetOpsToReplace_AllowQuantOps) { + TfLiteContext* context = interpreter_quant->context(); + + // These functions are meant to be called 
inside delegates. Swap out + // for similar functions to permit direct calling of GetOpsToReplace. + context->GetExecutionPlan = [](struct TfLiteContext* context, + TfLiteIntArray** execution_plan) { + *execution_plan = interpreter_quant->exec_plan(); + return kTfLiteOk; + }; + context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration) { + auto& node_and_reg = + interpreter_quant->nodes_and_registration()[node_index]; + *node = &node_and_reg.first; + *registration = &node_and_reg.second; + return kTfLiteOk; + }; + context->PreviewDelegatePartitioning = + [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegateParams** partition_params_array, int* num_partitions) { + if (nodes_to_replace->size == 0) { + *num_partitions = 0; + return kTfLiteOk; + } + auto params = interpreter_quant->add_delegate_params(); + params->nodes_to_replace = TfLiteIntArrayCreate(3); + params->nodes_to_replace->data[0] = 0; + params->nodes_to_replace->data[1] = 1; + params->nodes_to_replace->data[2] = 2; + params->input_tensors = TfLiteIntArrayCreate(2); + params->input_tensors->data[0] = 0; + params->input_tensors->data[1] = 3; + params->output_tensors = TfLiteIntArrayCreate(1); + params->output_tensors->data[0] = 4; + + *partition_params_array = interpreter_quant->delegate_params(); + *num_partitions = interpreter_quant->num_delegate_params(); + return kTfLiteOk; + }; + + TfLiteIntArray* ops_to_replace = + GetOpsToReplace(context, /**allow_quant_ops=*/true); + // If we allow quant ops, two QUANTIZE & one ADD node should be accepted. + EXPECT_EQ(ops_to_replace->size, 3); + EXPECT_EQ(0, ops_to_replace->data[0]); + EXPECT_EQ(1, ops_to_replace->data[1]); + EXPECT_EQ(2, ops_to_replace->data[2]); + + TfLiteIntArray* ops_to_replace_without_quant = + GetOpsToReplace(context, /**allow_quant_ops=*/false); + // No ops should be accepted. 
+ EXPECT_EQ(ops_to_replace_without_quant->size, 0); + + TfLiteIntArrayFree(ops_to_replace); + TfLiteIntArrayFree(ops_to_replace_without_quant); +} + } // namespace } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD index 232aa8f5161..4c76e4a81d3 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD +++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD @@ -105,6 +105,7 @@ cc_library( srcs = ["general_transformations.cc"], hdrs = ["general_transformations.h"], deps = [ + ":add_quant_adjustments", ":fuse_add_to_conv", ":fuse_mul_to_conv", ":make_fully_connected", diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc index 0e1273f8d4a..354fbcd040b 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" @@ -28,7 +29,9 @@ namespace gpu { bool ApplyGeneralTransformations(ModelTransformer* transformer) { // whenever any of these transforms return false, that means that a graph // is in the broken state and processing should not continue. 
- return transformer->Apply("remove_degenerate_upsampling", + return transformer->Apply("add_quant_adjustments", + NewAddQuantAdjustments().get()) && + transformer->Apply("remove_degenerate_upsampling", NewRemoveDegenerateUpsampling().get()) && transformer->Apply("remove_single_input_add", NewRemoveSingleInputAdd().get()) && diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 7bfc977f7af..540c8ba8c18 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include #include // NOLINT(build/c++11) +#include #include #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/api2.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/minimal_logging.h" namespace tflite { @@ -74,6 +77,11 @@ class Delegate { TfLiteDelegate* tflite_delegate() { return &delegate_; } const TfLiteGpuDelegateOptionsV2& options() const { return options_; } + bool IsQuantOpsAllowed() { + return options_.experimental_flags & + TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT; + } + private: TfLiteDelegate delegate_ = { .data_ = reinterpret_cast(this), @@ -100,24 +108,10 @@ class DelegateKernel { // Extract TFLite delegate execution plan from the context and convert it // into GraphFloat32. 
GraphFloat32 graph; - RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph)); - std::vector input_refs; - { - const auto& inputs = graph.inputs(); - input_refs.reserve(inputs.size()); - for (auto input : inputs) { - input_refs.push_back(input->tensor.ref); - } - } std::vector output_refs; - { - const auto& outputs = graph.outputs(); - output_refs.reserve(outputs.size()); - for (auto output : outputs) { - output_refs.push_back(output->tensor.ref); - } - } + RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph, + &input_refs, &output_refs)); std::unique_ptr builder; bool graph_is_destroyed; @@ -130,7 +124,8 @@ class DelegateKernel { // Graph need to be re-created because it is moved above. GraphFloat32 graph2; if (graph_is_destroyed) { - RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph2)); + RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2, + &input_refs, &output_refs)); } RETURN_IF_ERROR( InitializeOpenGlApi(graph_is_destroyed ? &graph2 : &graph, &builder)); @@ -156,6 +151,30 @@ class DelegateKernel { return builder->Build(&runner_); } + // This directs the runtime to allocate memory for input/output temporary + // tensors that require dequantization/quantization. 
+ absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node, + TfLiteIntArray** temporaries_array_ptr) { + if (quant_conversion_map_.empty()) return absl::OkStatus(); + + std::vector temporary_tensors; + for (auto index : input_indices_) { + if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) { + temporary_tensors.push_back(index); + } + } + for (auto index : output_indices_) { + if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) { + temporary_tensors.push_back(index); + } + } + *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensors.size()); + for (int i = 0; i < temporary_tensors.size(); ++i) { + (*temporaries_array_ptr)->data[i] = temporary_tensors[i]; + } + return absl::OkStatus(); + } + absl::Status Invoke(TfLiteContext* context) { if (thread_id_prepare_ != std::this_thread::get_id()) { TFLITE_LOG(tflite::TFLITE_LOG_WARNING, @@ -167,8 +186,16 @@ class DelegateKernel { } } + const bool is_dequant_required = !quant_conversion_map_.empty(); + if (is_dequant_required) { + RETURN_IF_ERROR(DequantizeInputs(context)); + } RETURN_IF_ERROR(SetInputsAndOutputs(context)); - return runner_->Run(); + RETURN_IF_ERROR(runner_->Run()); + if (is_dequant_required) { + RETURN_IF_ERROR(QuantizeOutputs(context)); + } + return absl::OkStatus(); } private: @@ -198,6 +225,101 @@ class DelegateKernel { return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes)); } + private: + absl::Status InitializeGraph(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph, + std::vector* input_refs, + std::vector* output_refs) { + quant_conversion_map_.clear(); + if (options_.experimental_flags & + TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT) { + RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph, + &quant_conversion_map_)); + } else { + RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph)); + } + + input_refs->clear(); + output_refs->clear(); + const auto& inputs = 
graph->inputs(); + input_refs->reserve(inputs.size()); + for (const auto& input : inputs) { + input_refs->push_back(input->tensor.ref); + } + const auto& outputs = graph->outputs(); + output_refs->reserve(outputs.size()); + for (const auto& output : outputs) { + output_refs->push_back(output->tensor.ref); + } + + return absl::OkStatus(); + } + + // TODO(b/150798231): Refactor these two into common utils when generalizing + // to other backends. + + // Dequantizes input tensors pre-inference, leaving float tensors intact. + absl::Status DequantizeInputs(TfLiteContext* context) { + for (auto index : input_indices_) { + if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) { + continue; + } + int original_tensor_idx = quant_conversion_map_[index]; + const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index]; + const TfLiteTensor& original_tflite_tensor = + context->tensors[original_tensor_idx]; + DequantizationParams op_params; + op_params.zero_point = original_tflite_tensor.params.zero_point; + op_params.scale = original_tflite_tensor.params.scale; + if (original_tflite_tensor.type == kTfLiteInt8) { + optimized_ops::Dequantize(op_params, + GetTensorShape(&original_tflite_tensor), + original_tflite_tensor.data.int8, + GetTensorShape(&original_tflite_tensor), + dequantized_tflite_tensor.data.f); + } else if (original_tflite_tensor.type == kTfLiteUInt8) { + optimized_ops::Dequantize(op_params, + GetTensorShape(&original_tflite_tensor), + original_tflite_tensor.data.uint8, + GetTensorShape(&original_tflite_tensor), + dequantized_tflite_tensor.data.f); + } + } + return absl::OkStatus(); + } + + // Quantizes output tensors post-inference, leaving float tensors intact. 
+ absl::Status QuantizeOutputs(TfLiteContext* context) { + for (auto index : output_indices_) { + if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) { + continue; + } + int original_tensor_idx = quant_conversion_map_[index]; + const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index]; + const TfLiteTensor& original_tflite_tensor = + context->tensors[original_tensor_idx]; + tflite::QuantizationParams op_params; + op_params.zero_point = original_tflite_tensor.params.zero_point; + op_params.scale = original_tflite_tensor.params.scale; + if (original_tflite_tensor.type == kTfLiteInt8) { + optimized_ops::AffineQuantize(op_params, + GetTensorShape(&original_tflite_tensor), + dequantized_tflite_tensor.data.f, + GetTensorShape(&original_tflite_tensor), + original_tflite_tensor.data.int8); + } else if (original_tflite_tensor.type == kTfLiteUInt8) { + optimized_ops::AffineQuantize(op_params, + GetTensorShape(&original_tflite_tensor), + dequantized_tflite_tensor.data.f, + GetTensorShape(&original_tflite_tensor), + original_tflite_tensor.data.uint8); + } + } + + return absl::OkStatus(); + } + absl::Status InitializeOpenClApi(GraphFloat32* graph, std::unique_ptr* builder, bool* graph_is_destroyed) { @@ -257,6 +379,10 @@ class DelegateKernel { std::unique_ptr runner_; std::vector input_indices_; std::vector output_indices_; + // Whenever quantized inference is enabled, this maps the tensor index of each + // originally quantized (8-bit) tensor to its float version added in + // model_builder - and vice versa. 
+ std::unordered_map quant_conversion_map_; std::thread::id thread_id_prepare_; // thread id used for Prapare() bool enforce_same_thread_ = false; // flag to enforce same thread for Invoke }; @@ -300,6 +426,14 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { "TfLiteGpuDelegate Prepare: delegate is not initialized"); return kTfLiteError; } + auto* gpu_delegate_kernel = GetDelegateKernel(node); + const auto status = gpu_delegate_kernel->GetRequiredTemporaries( + context, node, &node->temporaries); + if (!status.ok()) { + TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Prepare: %s", + std::string(status.message()).c_str()); + return kTfLiteError; + } // TODO(akulik): tflite tensors are not allocated here either. It would // be good to set inputs and outputs only once here instead of setting // them every time in .invoke. @@ -320,7 +454,8 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { "TfLiteGpuDelegateV2", // .custom_name 1, // .version }; - TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); + TfLiteIntArray* ops_to_replace = GetOpsToReplace( + context, /*allow_quant_ops=*/GetDelegate(delegate)->IsQuantOpsAllowed()); const auto status = context->ReplaceNodeSubsetsWithDelegateKernels( context, kRegistration, ops_to_replace, delegate); TfLiteIntArrayFree(ops_to_replace); @@ -340,6 +475,7 @@ TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() { options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION; options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO; options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO; + options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE; return options; } diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 29bececf39b..a60ebec84fe 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -60,6 +60,14 @@ enum 
TfLiteGpuInferencePriority { TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE = 3, }; +// Used to toggle experimental flags used in the delegate. Note that this is a +// bitmask, so the values should be 1, 2, 4, 8, ...etc. +enum TfLiteGpuExperimentalFlags { + TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE = 0, + // Enables inference on quantized models with the delegate. + TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT = 1 << 0 +}; + // IMPORTANT: Always use TfLiteGpuDelegateOptionsV2Default() method to create // new instance of TfLiteGpuDelegateOptionsV2, otherwise every new added option // may break inference. @@ -95,6 +103,9 @@ typedef struct { int32_t inference_priority1; int32_t inference_priority2; int32_t inference_priority3; + + // Bitmask flags. See the comments in TfLiteGpuExperimentalFlags. + int64_t experimental_flags; } TfLiteGpuDelegateOptionsV2; // Populates TfLiteGpuDelegateOptionsV2 as follows: diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc index 1d45e07aeee..7d24b3d1798 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.cc @@ -33,17 +33,12 @@ class QuantizeAndDequantize : public NodeShader { public: absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { - std::string code; - // Constants - code += "vec4 scale = vec4($quant_scale$);"; - code += "vec4 min_bound = vec4($quant_min$);"; - code += "vec4 max_bound = vec4($quant_max$);"; - // Quantize - code += "value_0 = clamp(value_0, min_bound, max_bound);"; - code += "value_0 = (value_0 - min_bound) / scale;"; - code += "value_0 = floor(value_0 + vec4(0.5));"; - // Dequantize - code += "value_0 = value_0 * scale + min_bound;"; + std::string code = R"( +value_0 = clamp(value_0, vec4($quant_min$), vec4($quant_max$)); +value_0 = (value_0 - vec4($quant_min$)) / 
vec4($quant_scale$); +value_0 = floor(value_0 + vec4(0.5)); +value_0 = value_0 * vec4($quant_scale$) + vec4($quant_min$); +)"; auto attr = absl::any_cast( ctx.node->operation.attributes); diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index a2df868f866..c22cbe86175 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -263,6 +263,8 @@ FloatPoolingOpTest/L2PoolActivationRelu.*,29 FloatPoolingOpTest/.+ # Image is too big -QuantizedPoolingOpTest/AveragePoolImageSize17 +# Int16 unsupported +-QuantizedPoolingOpTest/SymmetricAveragePool16 QuantizedPoolingOpTest/.+ QuantizedUInt8PoolingOpTest/.+ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc index 8ca2c442ed1..2efc767d703 100644 --- a/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc +++ b/tensorflow/lite/experimental/delegates/coreml/builders/fully_connected_op_builder.cc @@ -53,7 +53,7 @@ void FullyConnectedOpBuilder::FillCoreMLWeights() { layer_->mutable_innerproduct()->set_outputchannels(weights_->dims->data[0]); const float* weights_data = GetTensorData(weights_); std::copy(weights_data, weights_data + NumElements(weights_), - proto2::RepeatedFieldBackInserter(layer_->mutable_innerproduct() + google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct() ->mutable_weights() ->mutable_floatvalue())); } @@ -63,7 +63,7 @@ void FullyConnectedOpBuilder::FillCoreMLBias() { layer_->mutable_innerproduct()->set_hasbias(true); const float* bias_data = GetTensorData(bias_); std::copy(bias_data, bias_data + NumElements(bias_), - proto2::RepeatedFieldBackInserter(layer_->mutable_innerproduct() + google::protobuf::RepeatedFieldBackInserter(layer_->mutable_innerproduct() 
->mutable_bias() ->mutable_floatvalue())); } diff --git a/tensorflow/lite/experimental/kernels/BUILD b/tensorflow/lite/experimental/kernels/BUILD index 671d7f65851..e5d789690d3 100644 --- a/tensorflow/lite/experimental/kernels/BUILD +++ b/tensorflow/lite/experimental/kernels/BUILD @@ -1,5 +1,4 @@ load("//tensorflow/lite:build_def.bzl", "tflite_copts") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") package( default_visibility = [ @@ -173,14 +172,3 @@ cc_test( "@flatbuffers", ], ) - -tf_py_wrap_cc( - name = "hashtable_ops_py_wrapper", - srcs = [ - "hashtable_ops.i", - ], - deps = [ - ":hashtable_op_kernels", - "//third_party/python_runtime:headers", - ], -) diff --git a/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc index e3db29b1959..979ef7fe72a 100644 --- a/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc +++ b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/detail/common.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/detail/common.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" #include "tensorflow/lite/experimental/support/codegen/android_java_generator.h" #include "tensorflow/lite/experimental/support/codegen/code_generator.h" diff --git a/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc index 3f5d4e221c7..6185722504f 100644 --- a/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc +++ b/tensorflow/lite/experimental/support/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -15,9 +15,9 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/idl.h" // from @flatbuffers -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" namespace tflite { namespace support { diff --git a/tensorflow/lite/experimental/support/metadata/java/BUILD b/tensorflow/lite/experimental/support/metadata/java/BUILD index 3b51b4a8a27..f1cd6173b9e 100644 --- a/tensorflow/lite/experimental/support/metadata/java/BUILD +++ b/tensorflow/lite/experimental/support/metadata/java/BUILD @@ -16,7 +16,7 @@ android_library( deps = [ "//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android", "//tensorflow/lite/experimental/support/metadata:schema_fbs_android", - "//tensorflow/lite/java:tensorflowlite", + "//tensorflow/lite/java:tensorflowlite_java", "@org_checkerframework_qual", ], ) @@ -28,7 +28,7 @@ java_library( deps = [ 
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java", "//tensorflow/lite/experimental/support/metadata:schema_fbs_java", - "//tensorflow/lite/java:tensorflowlitelib", + "//tensorflow/lite/java:tensorflowlite_javalib", "@org_checkerframework_qual", ], ) diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index 1a769615eef..ab0958bea2f 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -136,7 +136,7 @@ upper_tabs: path: /lite/performance/model_optimization - title: "Post-training quantization" path: /lite/performance/post_training_quantization - - title: "Post-training weight quantization" + - title: "Post-training dynamic range quantization" path: /lite/performance/post_training_quant - title: "Post-training integer quantization" path: /lite/performance/post_training_integer_quant diff --git a/tensorflow/lite/g3doc/performance/model_optimization.md b/tensorflow/lite/g3doc/performance/model_optimization.md index 5a5772b4a1f..feb6cfecea6 100644 --- a/tensorflow/lite/g3doc/performance/model_optimization.md +++ b/tensorflow/lite/g3doc/performance/model_optimization.md @@ -86,12 +86,12 @@ a smaller model size and faster computation. 
The following types of quantization are available in TensorFlow Lite: -Technique | Data requirements | Size reduction | Accuracy | Supported hardware --------------------------------------------------------------------------------------------------------------- | -------------------------------- | -------------- | --------------------------- | ------------------ -[Post-training float16 quantization](post_training_float16_quant.ipynb) | No data | Up to 50% | Insignificant accuracy loss | CPU, GPU -[Post-training dynamic range quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU -[Post-training integer quantization](post_training_integer_quant.ipynb) | Unlabelled representative sample | Up to 75% | Smaller accuracy loss | CPU, EdgeTPU, Hexagon DSP -[Quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize) | Labelled training data | Up to 75% | Smallest accuracy loss | CPU, EdgeTPU, Hexagon DSP +Technique | Data requirements | Size reduction | Accuracy | Supported hardware +------------------------------------------------------------------------------------------------------- | -------------------------------- | -------------- | --------------------------- | ------------------ +[Post-training float16 quantization](post_training_float16_quant.ipynb) | No data | Up to 50% | Insignificant accuracy loss | CPU, GPU +[Post-training dynamic range quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU +[Post-training integer quantization](post_training_integer_quant.ipynb) | Unlabelled representative sample | Up to 75% | Smaller accuracy loss | CPU, EdgeTPU, Hexagon DSP +[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Labelled training data | Up to 75% | Smallest accuracy loss | CPU, EdgeTPU, Hexagon DSP Below are the latency and accuracy results for post-training quantization and quantization-aware 
training on a few models. All latency numbers are measured on @@ -144,11 +144,9 @@ broadly applicable and does not require training data. For cases where the accuracy and latency targets are not met, or hardware accelerator support is important, -[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external} +[quantization-aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training){:.external} is the better option. See additional optimization techniques under the [Tensorflow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization). -Note: Quantization-aware training supports a subset of convolutional neural network architectures. - If you want to further reduce your model size, you can try [pruning](#pruning) prior to quantizing your models. diff --git a/tensorflow/lite/g3doc/performance/post_training_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb index 1d566cadc84..d6edb656d0e 100644 --- a/tensorflow/lite/g3doc/performance/post_training_quant.ipynb +++ b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb @@ -41,7 +41,7 @@ "id": "6Y8E0lw5eYWm" }, "source": [ - "# Post-training weight quantization" + "# Post-training dynamic range quantization" ] }, { @@ -75,9 +75,7 @@ "\n", "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n", "converting weights to 8 bit precision as part of model conversion from\n", - "tensorflow graphdefs to TensorFlow Lite's flat buffer format. Weight quantization\n", - "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n", - "fly quantization and dequantization of activations to allow for:\n", + "tensorflow graphdefs to TensorFlow Lite's flat buffer format. Dynamic range quantization achieves a 4x reduction in the model size. In addition, TFLite supports on the fly quantization and dequantization of activations to allow for:\n", "\n", "1. 
Using quantized kernels for faster implementation when available.\n", "2. Mixing of floating-point kernels with quantized kernels for different parts\n", @@ -99,7 +97,7 @@ "\n", "This tutorial trains an MNIST model from scratch, checks its accuracy in\n", "TensorFlow, and then converts the model into a Tensorflow Lite flatbuffer\n", - "with weight quantization. Finally, it checks the\n", + "with dynamic range quantization. Finally, it checks the\n", "accuracy of the converted model and compare it to the original float model." ] }, @@ -295,7 +293,7 @@ }, "outputs": [], "source": [ - "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "tflite_quant_model = converter.convert()\n", "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n", "tflite_model_quant_file.write_bytes(tflite_quant_model)" @@ -497,7 +495,7 @@ "id": "Km3cY9ry8ZlG" }, "source": [ - "Repeat the evaluation on the weight quantized model to obtain:\n" + "Repeat the evaluation on the dynamic range quantized model to obtain:\n" ] }, { @@ -586,7 +584,7 @@ "outputs": [], "source": [ "# Convert to TF Lite with quantization\n", - "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "resnet_quantized_tflite_file = tflite_models_dir/\"resnet_v2_101_quantized.tflite\"\n", "resnet_quantized_tflite_file.write_bytes(converter.convert())" ] diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index af72ec258d9..1aabdf6409b 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -197,6 +197,11 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, return primary_subgraph().ResizeInputTensor(tensor_index, dims); } +TfLiteStatus Interpreter::ResizeInputTensorStrict( + int tensor_index, const std::vector& dims) { + return primary_subgraph().ResizeInputTensorStrict(tensor_index, dims); +} + 
TfLiteStatus Interpreter::ReleaseNonPersistentMemory() { // TODO(b/138790287): We could do this for all subgraphs whose tensors have // been allocated. However, AllocateTensors() relies on Control Flow ops to @@ -256,10 +261,12 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable, + const size_t rank_dims_signature, const int* dims_signature) { TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization); return primary_subgraph().SetTensorParametersReadWrite( - tensor_index, type, name, rank, dims, new_quantization, is_variable); + tensor_index, type, name, rank, dims, new_quantization, is_variable, + rank_dims_signature, dims_signature); } TfLiteStatus Interpreter::SetExecutionPlan(const std::vector& new_plan) { diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index a869c1368d2..c6a86572682 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -166,14 +166,23 @@ class Interpreter { inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, - bool is_variable = false) { - return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization, is_variable); + bool is_variable = false, + const std::vector* dims_signature = nullptr) { + size_t rank_dims_signature = 0; + const int* dims_signature_pointer = nullptr; + if (dims_signature) { + rank_dims_signature = dims_signature->size(); + dims_signature_pointer = dims_signature->data(); + } + return SetTensorParametersReadWrite( + tensor_index, type, name, dims.size(), dims.data(), quantization, + is_variable, rank_dims_signature, 
dims_signature_pointer); } TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization, - bool is_variable = false); + bool is_variable = false, const size_t rank_dims_signature = 0, + const int* dims_signature = nullptr); #endif // DOXYGEN_SKIP // Functions to access tensor data @@ -319,6 +328,15 @@ class Interpreter { TfLiteStatus ResizeInputTensor(int tensor_index, const std::vector& dims); + // WARNING: Experimental interface, subject to change + // Change the dimensionality of a given tensor. This is only acceptable for + // tensor indices that are inputs or variables. Only unknown dimensions can be + // resized with this function. Unknown dimensions are indicated as `-1` in the + // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure + // or success. + TfLiteStatus ResizeInputTensorStrict(int tensor_index, + const std::vector& dims); + // This releases memory held by non-persistent tensors. It does NOT re-perform // memory planning. // AllocateTensors needs to be called before next invocation. diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index ad0e19f4f89..38f9cd26f40 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -551,6 +551,65 @@ TEST(BasicInterpreter, NoopResizingTensors) { ASSERT_EQ(tensor->data.f[5], 0.123f); } +TEST(BasicInterpreter, ResizingTensorsStrictInvalid) { + // Tests ResizeInputTensorStrict where `dims_signature` is not specified. 
+  Interpreter interpreter; +  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); +  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); +  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + +  ASSERT_EQ(interpreter.SetTensorParametersReadWrite( +                0, kTfLiteFloat32, "", {1, 1, 3}, TfLiteQuantizationParams()), +            kTfLiteOk); + +  int t = interpreter.inputs()[0]; +  TfLiteTensor* tensor = interpreter.tensor(t); + +  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 1, 3}), kTfLiteOk); +  EXPECT_EQ(tensor->bytes, 3 * sizeof(float)); +  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + +  // Invalid because `dims_signature` is not specified. +  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 3}), kTfLiteError); +  EXPECT_EQ(tensor->bytes, 3 * sizeof(float)); +  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + +  // Assert that ResizeInputTensor works for this value. +  ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); +  EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); +  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + +TEST(BasicInterpreter, ResizingTensorsStrict) { +  // Tests ResizeInputTensorStrict where `dims_signature` is specified.
+ Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + std::vector dims_signature = {-1, -1, 3}; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {1, 1, 3}, TfLiteQuantizationParams(), + false, &dims_signature), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 4}), kTfLiteError); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Assert that ResizeInputTensor works for this value. + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + // Simple op that does input = output. 
TfLiteRegistration GetPassthroughOpRegistration() { TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index bfd610b759a..857974ecce2 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -105,6 +105,15 @@ android_library( ], ) +java_library( + name = "tensorflowlite_javalib", + srcs = JAVA_SRCS, + javacopts = JAVACOPTS, + deps = [ + "@org_checkerframework_qual", + ], +) + java_library( name = "tensorflowlitelib", srcs = JAVA_SRCS, diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index a4d188f34da..5bd03b0d14a 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -230,14 +230,15 @@ cc_test( cc_library( name = "tflite_with_ruy_enabled", + build_for_embedded = True, defines = ["TFLITE_WITH_RUY"], visibility = ["//visibility:private"], ) cc_library( name = "tflite_with_ruy_default", - visibility = ["//visibility:private"], - deps = select({ + build_for_embedded = True, + select_deps = { ":chromiumos_arm64": [":tflite_with_ruy_enabled"], ":cpu_aarch64": [":tflite_with_ruy_enabled"], ":cpu_arm64": [":tflite_with_ruy_enabled"], @@ -247,16 +248,18 @@ cc_library( ":cpu_arm64_v8a": [":tflite_with_ruy_enabled"], "//tensorflow:android_arm": ["tflite_with_ruy_enabled"], "//conditions:default": [], - }), + }, + visibility = ["//visibility:private"], ) cc_library( name = "tflite_with_ruy", - deps = select({ + build_for_embedded = True, + select_deps = { ":tflite_with_ruy_explicit_true": [":tflite_with_ruy_enabled"], ":tflite_with_ruy_explicit_false": [], "//conditions:default": [":tflite_with_ruy_default"], - }), + }, ) cc_library( @@ -267,6 +270,7 @@ cc_library( hdrs = [ "cpu_backend_context.h", ], + build_for_embedded = True, copts = tflite_copts(), deps = [ ":tflite_with_ruy", @@ -376,6 +380,7 @@ cc_library( hdrs = [ "kernel_util.h", ], + build_for_embedded = True, copts = tflite_copts() + micro_copts(), deps = [ 
"//tensorflow/lite/c:common", @@ -411,6 +416,7 @@ cc_library( name = "padding", srcs = [], hdrs = ["padding.h"], + build_for_embedded = True, copts = tflite_copts(), deps = [ "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/kernels/fill.cc b/tensorflow/lite/kernels/fill.cc index b5e0c0cc9ff..19ff1de4939 100644 --- a/tensorflow/lite/kernels/fill.cc +++ b/tensorflow/lite/kernels/fill.cc @@ -94,6 +94,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) { + DynamicBuffer buffer; + const auto string_ref = GetString(value, 0); + int n = 1; + for (int i = 0; i < output->dims->size; ++i) { + n *= output->dims->data[i]; + } + for (int i = 0; i < n; ++i) { + buffer.AddString(string_ref.str, string_ref.len); + } + buffer.WriteToTensor(output, /*new_shape=*/nullptr); + return kTfLiteOk; +} + TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* value = GetInput(context, node, kValueTensor); @@ -117,11 +131,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteFloat32: TF_LITE_FILL(float); break; + case kTfLiteBool: + TF_LITE_FILL(bool); + break; + case kTfLiteString: + FillString(value, output); + break; default: context->ReportError( context, - "Fill only currently supports int32, int64, float32 for input 1," - "got %d.", + "Fill only currently supports int32, int64, float32, bool, string " + "for input 1, got %d.", value->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/fill_test.cc b/tensorflow/lite/kernels/fill_test.cc index 5e359a86efb..4ab013bb357 100644 --- a/tensorflow/lite/kernels/fill_test.cc +++ b/tensorflow/lite/kernels/fill_test.cc @@ -84,5 +84,27 @@ TEST(FillOpModel, FillOutputScalar) { EXPECT_THAT(m.GetTensorShape(m.output()), IsEmpty()); } +TEST(FillOpModel, FillBool) { + FillOpModel m({TensorType_INT64, {3}}, {TensorType_BOOL}); + m.PopulateTensor(m.input1(), {2, 2, 
2}); + m.PopulateTensor(m.input2(), {true}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector(m.output()), + ElementsAreArray({true, true, true, true, true, true, true, true})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2})); +} + +TEST(FillOpModel, FillString) { + FillOpModel m({TensorType_INT64, {3}}, {TensorType_STRING}); + m.PopulateTensor(m.input1(), {2, 2, 2}); + m.PopulateTensor(m.input2(), {"AB"}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector(m.output()), + ElementsAreArray({"AB", "AB", "AB", "AB", "AB", "AB", "AB", "AB"})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/hashtable_lookup.cc b/tensorflow/lite/kernels/hashtable_lookup.cc index 62a15c68e29..a432dcb8e22 100644 --- a/tensorflow/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/lite/kernels/hashtable_lookup.cc @@ -92,7 +92,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } status = context->ResizeTensor(context, output, outputSize); } - if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + if (context->ResizeTensor(context, hits, hitSize) != kTfLiteOk) { status = kTfLiteError; } return status; diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 373fffd8c24..f5d5b6da31f 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -62,6 +62,7 @@ cc_library( cc_library( name = "legacy_types", hdrs = ["legacy_types.h"], + build_for_embedded = True, copts = tflite_copts(), deps = [ ":types", @@ -201,6 +202,7 @@ cc_library( name = "common", srcs = [], hdrs = ["common.h"], + build_for_embedded = True, copts = tflite_copts(), deps = [ ":cpu_check", @@ -357,6 +359,7 @@ cc_library( name = "quantization_util", srcs = ["quantization_util.cc"], hdrs = ["quantization_util.h"], + build_for_embedded = True, copts = tflite_copts() + micro_copts(), deps = [ 
":compatibility", @@ -384,6 +387,7 @@ cc_library( hdrs = [ "transpose_utils.h", ], + build_for_embedded = True, copts = tflite_copts(), deps = [ ":types", @@ -405,6 +409,7 @@ cc_library( hdrs = [ "strided_slice_logic.h", ], + build_for_embedded = True, copts = tflite_copts(), deps = [ ":compatibility", @@ -463,23 +468,9 @@ cc_library( "reference/sub.h", "reference/svdf.h", ], + build_for_embedded = True, copts = tflite_copts(), - deps = [ - ":common", - ":compatibility", - ":quantization_util", - ":cppmath", - ":strided_slice_logic", - ":tensor", - ":tensor_utils", - ":types", - "@gemmlowp//:fixedpoint", - "//third_party/eigen3", - "//tensorflow/lite/c:common", - "//tensorflow/lite/kernels:op_macros", - "@ruy//ruy/profiler:instrumentation", - "//tensorflow/lite/tools/optimize/sparsity:format_converter", - ] + select({ + select_deps = { ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, ":k8": tflite_deps_intel, @@ -490,7 +481,23 @@ cc_library( ":freebsd": tflite_deps_intel, ":windows": tflite_deps_intel, "//conditions:default": [], - }), + }, + deps = [ + ":common", + ":compatibility", + ":cppmath", + ":quantization_util", + ":strided_slice_logic", + ":tensor", + ":tensor_utils", + ":types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/tools/optimize/sparsity:format_converter", + "//third_party/eigen3", + "@gemmlowp//:fixedpoint", + "@ruy//ruy/profiler:instrumentation", + ], ) cc_library( @@ -658,6 +665,7 @@ cc_library( name = "kernel_utils", srcs = ["kernel_utils.cc"], hdrs = ["kernel_utils.h"], + build_for_embedded = True, copts = tflite_copts() + micro_copts(), deps = [ ":tensor_utils", @@ -695,13 +703,9 @@ cc_library( hdrs = [ "tensor_utils.h", ], + build_for_embedded = True, copts = tflite_copts() + NEON_FLAGS_IF_APPLICABLE, - deps = [ - ":cpu_check", - "//third_party/eigen3", - "//tensorflow/lite/c:common", - "//tensorflow/lite/kernels:cpu_backend_context", - ] + select({ + select_deps = { 
":aarch64": [ ":neon_tensor_utils", ], @@ -757,7 +761,13 @@ cc_library( "//conditions:default": [ ":portable_tensor_utils", ], - }), + }, + deps = [ + ":cpu_check", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:cpu_backend_context", + "//third_party/eigen3", + ], ) cc_library( @@ -995,10 +1005,9 @@ cc_library( "optimized/neon_check.h", "optimized/sse_check.h", ], + build_for_embedded = True, copts = tflite_copts(), - deps = [ - "//tensorflow/lite/kernels:cpu_backend_context", - ] + select({ + select_deps = { ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, ":k8": tflite_deps_intel, @@ -1009,7 +1018,10 @@ cc_library( ":freebsd": tflite_deps_intel, ":windows": tflite_deps_intel, "//conditions:default": [], - }), + }, + deps = [ + "//tensorflow/lite/kernels:cpu_backend_context", + ], ) cc_test( diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h index 2762bec8e6c..6b49d2b150b 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h @@ -135,6 +135,121 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } } +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const int16* input_data, + const RuntimeShape& output_shape, int16* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = 
params.stride_height; + const int stride_width = params.stride_width; + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(params.filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(params.filter_height, input_height - in_y_origin); + int32 acc = 0; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + acc += + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; + filter_count++; + } + } + // Round to the closest integer value. + acc = acc > 0 ? 
(acc + filter_count / 2) / filter_count + : (acc - filter_count / 2) / filter_count; + acc = std::max(acc, params.quantized_activation_min); + acc = std::min(acc, params.quantized_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + static_cast(acc); + } + } + } + } +} + +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const int16* input_data, const RuntimeShape& output_shape, + int16* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_GE(params.quantized_activation_min, + std::numeric_limits::min()); + TFLITE_DCHECK_LE(params.quantized_activation_max, + std::numeric_limits::max()); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(params.filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(params.filter_height, input_height - in_y_origin); + int16_t max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); + } + } + max = std::max(max, params.quantized_activation_min); + max = std::min(max, params.quantized_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + static_cast(max); + } + } + } + } +} + } // namespace reference_integer_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/pooling.cc b/tensorflow/lite/kernels/pooling.cc index 63c6eb1239f..0dcb667e901 100644 --- a/tensorflow/lite/kernels/pooling.cc +++ b/tensorflow/lite/kernels/pooling.cc @@ -197,6 +197,32 @@ void AverageEvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, #undef TF_LITE_AVERAGE_POOL } +template +void AverageEvalQuantizedInt16(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, + TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width 
= data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::AveragePool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) + TF_LITE_AVERAGE_POOL(reference_integer_ops); +#undef TF_LITE_AVERAGE_POOL +} + template void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, @@ -282,6 +308,31 @@ void MaxEvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, #undef TF_LITE_MAX_POOL } +template +void MaxEvalQuantizedInt16(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); +#define TF_LITE_MAX_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) + TF_LITE_MAX_POOL(reference_integer_ops); +#undef TF_LITE_MAX_POOL +} + template void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, @@ -330,6 +381,10 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { AverageEvalQuantizedInt8(context, node, params, data, input, output); break; + case kTfLiteInt16: + AverageEvalQuantizedInt16(context, node, params, data, input, + output); + break; 
default: context->ReportError(context, "Type %d not currently supported.", input->type); @@ -357,6 +412,10 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { MaxEvalQuantizedInt8(context, node, params, data, input, output); break; + case kTfLiteInt16: + MaxEvalQuantizedInt16(context, node, params, data, input, + output); + break; default: context->ReportError(context, "Type %d not currently supported.", input->type); diff --git a/tensorflow/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc index 1b371361a4d..e609f04e21d 100644 --- a/tensorflow/lite/kernels/pooling_test.cc +++ b/tensorflow/lite/kernels/pooling_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include + #include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -96,6 +97,25 @@ class SymmetricQuantizedPoolingOpModel : public BasePoolingOpModel { } }; +class SymmetricQuantizedPoolingOpModel16 : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetInput(const std::vector& data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + // Replicate each entry in a vector n times along depth (innermost dimension). // The values are incremented by delta, creating ramps offset by each input // value. This is used to create simple and predicatable variation. @@ -398,6 +418,29 @@ TEST(QuantizedPoolingOpTest, AveragePoolLargeDepth) { ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f), 1. 
/ 32.f))); } + +// Test quantized AveragePool with int16 input and output. The input is the same +// as the uint8 test QuantizedPoolingOpTest.AveragePool but with a scale of +// 1/4096 rather than 1/16. +TEST(QuantizedPoolingOpTest, SymmetricAveragePool16) { + const float ulp = 1.f / 4096.f; + SymmetricQuantizedPoolingOpModel16 m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 0, 16 - ulp}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_INT16, {}, 0, 16 - ulp}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({(44 - 128) * 256, (92 - 128) * 256})); +} + // Test quantized AveragePool with int8 input and output. The input is the same // as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is // identical to uint8 test and quantized output is identical to uint8 test with @@ -858,6 +901,28 @@ TEST(QuantizedInt8PoolingOpTest, MaxPool) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({96 - 128, 160 - 128})); } +TEST(QuantizedInt8PoolingOpTest16, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. 
+ // Input Range[0, 16-(1/4096)] --> [Scale{(1/4096)}, zero_point{-32768}] + const float ulp = 1.f / 4096.f; + SymmetricQuantizedPoolingOpModel16 m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 0, 16 - ulp}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_INT16, {}, 0, 16 - ulp}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6, 10}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({(96 - 128) * 256, (160 - 128) * 256})); +} + TEST(QuantizedInt8PoolingOpTest, MaxPoolActivationRelu) { // Choose the input ranges carefully so that the dequantized output matches // the results of the float model above. diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index faa4d818e0a..619f2a650c9 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -44,11 +44,11 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version = */ 1, /* max_version = */ 3); AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D(), - /* min_version = */ 1, - /* max_version = */ 2); + /* min_version */ 1, + /* max_version */ 3); AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D(), - /* min_version = */ 1, - /* max_version = */ 2); + /* min_version */ 1, + /* max_version */ 3); AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), /* min_version = */ 1, @@ -255,7 +255,9 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version = */ 1, /* max_version = */ 2); AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE()); - AddBuiltin(BuiltinOperator_FILL, Register_FILL()); + AddBuiltin(BuiltinOperator_FILL, Register_FILL(), + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); AddBuiltin(BuiltinOperator_UNIQUE, 
Register_UNIQUE()); AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(), diff --git a/tensorflow/lite/micro/benchmarks/conv_benchmark.cc b/tensorflow/lite/micro/benchmarks/conv_benchmark.cc index cd258312822..9153b3015c1 100644 --- a/tensorflow/lite/micro/benchmarks/conv_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/conv_benchmark.cc @@ -230,7 +230,7 @@ int main() { TfLiteStatus status = tflite::testing::ValidateConvGoldens( tensors, num_tensors, &conv_params, kQuantizationTolerance, output_dims_count, golden_quantized); - if (status == kTfLiteError) { + if (status != kTfLiteOk) { printf("Model invoke failed\n"); } return 0; diff --git a/tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc b/tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc index d0170b588ef..e4e96f8e029 100644 --- a/tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc @@ -241,7 +241,7 @@ int main() { TfLiteStatus status = tflite::testing::ValidateDepthwiseConvGoldens( tensors, kTensorsSize, kTfLiteActNone, kQuantizationTolerance, output_elements, golden_quantized); - if (status == kTfLiteError) { + if (status != kTfLiteOk) { printf("Model invoke failed\n"); } return 0; diff --git a/tensorflow/lite/micro/build_def.bzl b/tensorflow/lite/micro/build_def.bzl index f6e36255c22..ef37c92d9cd 100644 --- a/tensorflow/lite/micro/build_def.bzl +++ b/tensorflow/lite/micro/build_def.bzl @@ -12,6 +12,12 @@ def micro_copts(): def cc_library(**kwargs): kwargs.pop("build_for_embedded", False) + if "select_deps" in kwargs.keys(): + select_deps = kwargs.pop("select_deps", {}) + if "deps" in kwargs.keys(): + kwargs["deps"] += select(select_deps) + else: + kwargs["deps"] = select(select_deps) _cc_library(**kwargs) def flatbuffer_cc_library(**kwargs): diff --git a/tensorflow/lite/micro/examples/hello_world/BUILD b/tensorflow/lite/micro/examples/hello_world/BUILD index 25cf97bdd82..c03069e4ecc 100644 --- 
a/tensorflow/lite/micro/examples/hello_world/BUILD +++ b/tensorflow/lite/micro/examples/hello_world/BUILD @@ -7,6 +7,7 @@ load( ) load( "//tensorflow/lite/micro:build_def.bzl", + "cc_library", "micro_copts", ) @@ -22,6 +23,7 @@ cc_library( hdrs = [ "sine_model_data.h", ], + build_for_embedded = True, copts = micro_copts(), ) diff --git a/tensorflow/lite/micro/examples/micro_speech/train_speech_model.ipynb b/tensorflow/lite/micro/examples/micro_speech/train_speech_model.ipynb index c528ea16098..0baaeac0482 100644 --- a/tensorflow/lite/micro/examples/micro_speech/train_speech_model.ipynb +++ b/tensorflow/lite/micro/examples/micro_speech/train_speech_model.ipynb @@ -120,10 +120,10 @@ "colab": {} }, "source": [ - "# Replace Colab's default TensorFlow install with a more recent\n", + "# Replace Colab's default TensorFlow install with an older\n", "# build that contains the operations that are needed for training\n", "!pip uninstall -y tensorflow tensorflow_estimator tensorboard\n", - "!pip install -q tf-estimator-nightly==1.14.0.dev2019072901 tf-nightly-gpu==1.15.0.dev20190729" + "!pip install -q tensorflow==1.15" ], "execution_count": 0, "outputs": [] @@ -147,10 +147,7 @@ }, "source": [ "# Clone the repository from GitHub\n", - "!git clone -q https://github.com/tensorflow/tensorflow\n", - "# Check out a commit that has been tested to work\n", - "# with the build of TensorFlow we're using\n", - "!git -c advice.detachedHead=false -C tensorflow checkout 17ce384df70" + "!git clone -q https://github.com/tensorflow/tensorflow\n" ], "execution_count": 0, "outputs": [] @@ -209,7 +206,7 @@ "--wanted_words=${WANTED_WORDS} --silence_percentage=25 --unknown_percentage=25 \\\n", "--quantize=1 --verbosity=WARN --how_many_training_steps=${TRAINING_STEPS} \\\n", "--learning_rate=${LEARNING_RATE} --summaries_dir=/content/retrain_logs \\\n", - "--data_dir=/content/speech_dataset --train_dir=/content/speech_commands_train \\\n" + "--data_dir=/content/speech_dataset 
--train_dir=/content/speech_commands_train\n" ], "execution_count": 0, "outputs": [] diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index a0ffa342008..37709f94284 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -4,6 +4,7 @@ load( ) load( "//tensorflow/lite/micro:build_def.bzl", + "cc_library", "micro_copts", ) @@ -14,6 +15,11 @@ package( licenses = ["notice"], # Apache 2.0 ) +config_setting( + name = "xtensa_hifimini", + define_values = {"tflm_build": "xtensa_hifimini"}, +) + # LINT.IfChange(micro_ops) cc_library( name = "micro_ops", @@ -25,12 +31,9 @@ cc_library( "circular_buffer.cc", "comparisons.cc", "concatenation.cc", - "conv.cc", - "depthwise_conv.cc", "dequantize.cc", "elementwise.cc", "floor.cc", - "fully_connected.cc", "logical.cc", "logistic.cc", "maximum_minimum.cc", @@ -40,18 +43,36 @@ cc_library( "pad.cc", "pooling.cc", "prelu.cc", - "quantize.cc", "reduce.cc", "reshape.cc", "round.cc", - "softmax.cc", "split.cc", "strided_slice.cc", "sub.cc", - "svdf.cc", "unpack.cc", - ], + ] + select({ + "//conditions:default": [ + "conv.cc", + "depthwise_conv.cc", + "fully_connected.cc", + "quantize.cc", + "softmax.cc", + "svdf.cc", + ], + ":xtensa_hifimini": [ + "xtensa_hifimini/conv.cc", + "xtensa_hifimini/depthwise_conv.cc", + "xtensa_hifimini/fixedpoint_utils.h", + "xtensa_hifimini/fully_connected.cc", + "xtensa_hifimini/quantize.cc", + "xtensa_hifimini/softmax.cc", + "xtensa_hifimini/svdf.cc", + "xtensa_hifimini/utils.h", + ], + }), hdrs = ["micro_ops.h"], + # TODO(b/153609488): enable embedded build once we can properly support it. 
+ #build_for_embedded = True, copts = micro_copts(), deps = [ ":activation_utils", @@ -67,7 +88,12 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/micro:micro_utils", - ], + ] + select({ + "//conditions:default": [], + ":xtensa_hifimini": [ + #"//third_party/xtensa:hifi_mini_cstub64", + ], + }), ) # LINT.ThenChange(//tensorflow/lite/micro/kernels/BUILD:portable_optimized_micro_ops) @@ -79,6 +105,8 @@ cc_library( hdrs = [ "all_ops_resolver.h", ], + # TODO(b/153609488): enable embedded build once we can properly support it. + #build_for_embedded = True, copts = micro_copts(), deps = [ ":micro_ops", @@ -506,6 +534,7 @@ tflite_micro_cc_test( cc_library( name = "activation_utils", hdrs = ["activation_utils.h"], + build_for_embedded = True, deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:cppmath", @@ -543,6 +572,7 @@ tflite_micro_cc_test( cc_library( name = "micro_utils", hdrs = ["micro_utils.h"], + build_for_embedded = True, ) tflite_micro_cc_test( diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 73f200d454b..54c923cd314 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -71,33 +71,16 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - OpData* data = nullptr; - TfLiteStatus status = context->AllocatePersistentBuffer( - context, sizeof(OpData), reinterpret_cast(&data)); - if (status != kTfLiteOk || data == nullptr) { - return nullptr; - } - return data; + return nullptr; } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - OpData* data = reinterpret_cast(node->user_data); - auto* params = - reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = 
GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); - - TfLiteType data_type = input->type; - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, - filter, bias, output, data)); - return kTfLiteOk; } @@ -192,7 +175,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - OpData* data = reinterpret_cast(node->user_data); + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); // Checks in Prepare ensure input, output and filter types are all the same. 
switch (input->type) { diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 539c7ecc3a4..32a1b67b88e 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -49,6 +49,7 @@ void TestFullyConnectedFloat( TfLiteContext context; PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + ::tflite::ops::micro::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc index 983b3da35d5..b783c3c3e0f 100644 --- a/tensorflow/lite/micro/kernels/reshape_test.cc +++ b/tensorflow/lite/micro/kernels/reshape_test.cc @@ -79,7 +79,7 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor, if (registration->prepare) { // Error can happen either in Prepare or eval stage. auto status = registration->prepare(&context, &node); - if (status == kTfLiteError && expect_failure) { + if (status != kTfLiteOk && expect_failure) { return; } else { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status); diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc index 15ba0c33ee2..a1de5ef22e0 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc @@ -204,7 +204,7 @@ struct OpData { // These constants represent constants specific to the music detect model. // They exist until (b/132070898) is fixed. 
static const int kMaxOpDataSize = 6; -static int kStaticOpDataCounter = 0; +static int op_data_counter = 0; static OpData kStaticOpData[kMaxOpDataSize]; TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, @@ -245,6 +245,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } +void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -254,7 +256,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. - OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; + OpData* op_data = &kStaticOpData[op_data_counter++]; node->user_data = op_data; int input_width = input->dims->data[2]; @@ -344,7 +346,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_CONV_2D() { static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, + /*free=*/conv::Free, /*prepare=*/conv::Prepare, /*invoke=*/conv::Eval, /*profiling_string=*/nullptr, diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc index 554ec1d177c..3760dd71838 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc @@ -207,7 +207,7 @@ struct OpData { // These constants represent constants specific to the music detect model. // They exist until (b/132070898) is fixed. 
static const int kMaxOpDataSize = 6; -static int kStaticOpDataCounter = 0; +static int op_data_counter = 0; static OpData kStaticOpData[kMaxOpDataSize]; TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, @@ -249,6 +249,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, } // namespace +void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -258,7 +260,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. - OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; + OpData* op_data = &kStaticOpData[op_data_counter++]; node->user_data = op_data; const TfLiteType data_type = input->type; @@ -352,7 +354,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, + /*free=*/depthwise_conv::Free, /*prepare=*/depthwise_conv::Prepare, /*invoke=*/depthwise_conv::Eval, /*profiling_string=*/nullptr, diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h index f35ffaa741e..4ffb3653f50 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h @@ -31,23 +31,6 @@ namespace micro { namespace xtensa { namespace hifimini { -// -// Product of two fixed-point 24bit integers with right shift. -// -// Two 24bit integers from the HH side of a PR register entry are MAC into a QR -// register. That value will be right shifted if |shift_length| is greater than -// 0. 
-// -inline ae_q56s SaturatingMultiply(ae_p24x2s a_56, ae_p24x2s b_56, - int shift_length) { - ae_q56s result_56 = AE_ZEROQ56(); - AE_MULAS56P24S_HH(result_56, a_56, b_56); - if (shift_length > 0) { - return AE_Q56S_SRA(result_56, shift_length); - } - return result_56; -} - // // Multiply 32bit value by a quantized multiplier (w/ shift) and returns a 48bit // aligned value in the QR register. @@ -58,7 +41,7 @@ inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x, // These boolean factors will carry an additional 2^8 (e.g 256) factor // throughout the equation to cover the missing 8 bits of precision when a // 32bit integer is outside the bounds of INT24. The additional scaling factor - // will be adjusted on the final SaturatingMultiply() call in this method. + // will be adjusted after the final multiplication in this method. // // The Q-notation comments in this method describe the calculations that take // place when both |x| and the shifted value of |1| overflow the INT24 limits. @@ -78,10 +61,8 @@ inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x, // Q31.0 -> Q23.0 / 2^8 ae_p24x2s shifted_24x2 = AE_CONVERT_INT32_24x2(shifted); - // Multiply/accumulate sum and multiplier: // (Q23.0 / 2^8) * (Q23.0 / 2^8) = Q47.0 / 2^16 - ae_q56s sum_56 = AE_ZEROQ56(); - AE_MULAS56P24S_HH(sum_56, x_24x2, shifted_24x2); + ae_q56s sum_56 = AE_MULP24S_HH(x_24x2, shifted_24x2); // Shift left into 24bit space: // ((Q47.0 / 2^16) << 24) = Q23.24 / 2^16 @@ -113,8 +94,10 @@ inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x, // to 48bit aligned. // (Q23.0 / 2^16) * Q23.0 = Q47.0 / 2^16 // (Q47.0 / 2^16) >> 7 = Q47.0 - ae_q56s result_56 = - SaturatingMultiply(x_24x2, quantized_multiplier_24x2, shift_amount); + ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2); + if (shift_amount > 0) { + result_56 = AE_Q56S_SRA(result_56, shift_amount); + } if (shift < 0) { // Handle any negative shift directly on the 48 bit value. 
@@ -137,8 +120,7 @@ inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, // the limits of INT24, which requires |AE_CONVERT_INT32_24x2()| to load the // left-most 24 bits of a 32bit integer. When this occurs, all Q values here // carry an additional division of 2^8 to account for this loss in precision. - // This division will be applied to the final shift of the result in - // |SaturatingMultiply()|. + // This division will be applied to the final shift after multiplication. // // The Q-notation comments in this method describe the calculations that take // place when both |x| and the shifted value of |1| overflow the INT24 limits. @@ -154,11 +136,8 @@ inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, // Q31.0 -> Q23.0 / 2^8 ae_p24x2s shifted_24x2 = AE_CONVERT_INT32_24x2(shifted); - // Multiply/accumulate sum and multiplier: - ae_q56s sum_56 = AE_ZEROQ56(); - // Multiply/accumulate sum and multiplier: // Q23.0 * (Q23.0 / 2^8) = Q47.0 / 2^8 - AE_MULAS56P24S_HH(sum_56, x_24x2, shifted_24x2); + ae_q56s sum_56 = AE_MULP24S_HH(x_24x2, shifted_24x2); // Shift left into 24bit space: // ((Q47.0 / 2^8) << 24) = Q23.24 / 2^8 @@ -182,8 +161,8 @@ inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, // function: // (Q23.0 / 2^8) * Q23.0 = Q47.0 / 2^8 // (Q47.0 / 2^8) >> 7 = Q47.0 - ae_q56s result = SaturatingMultiply(x_shifted_24x2, quantized_multiplier_24x2, - shift_exceeds_24bits ? 15 : 23); + ae_q56s result = AE_MULP24S_HH(x_shifted_24x2, quantized_multiplier_24x2); + result = AE_Q56S_SRA(result, shift_exceeds_24bits ? 15 : 23); if (shift < 0) { // Handle any negative shift directly on the 48 bit value. 
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc index 024ce06273c..7a535120216 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc @@ -55,8 +55,8 @@ inline void FullyConnected( const int accum_depth = filter_shape.Dims(filter_dim_count - 1); const int accum_depth_iters = accum_depth / 2; - ae_p24x2s offsets_input_24x2 = AE_MOVPA24X2(input_offset, input_offset); - ae_p24x2s offsets_filter_24x2 = AE_MOVPA24X2(filter_offset, filter_offset); + ae_p24x2s offsets_input_24x2 = AE_MOVPA24(input_offset); + ae_p24x2s offsets_filter_24x2 = AE_MOVPA24(filter_offset); ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max); ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min); @@ -148,7 +148,7 @@ constexpr int kOutputTensor = 0; // This size will work for both the hotword (5) and ambient music (2): constexpr int kMaxOpDataSize = 7; -static int kStaticOpDataCounter = 0; +static int op_data_counter = 0; static OpData kStaticOpData[kMaxOpDataSize]; TfLiteStatus CalculateOpData(TfLiteContext* context, @@ -175,6 +175,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace +void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -187,7 +189,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. 
- OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; + OpData* op_data = &kStaticOpData[op_data_counter++]; node->user_data = op_data; TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, @@ -246,7 +248,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_FULLY_CONNECTED() { static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, + /*free=*/fully_connected::Free, /*prepare=*/fully_connected::Prepare, /*invoke=*/fully_connected::Eval, /*profiling_string=*/nullptr, diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc index 9a119cea528..0ac5ab821df 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc @@ -55,10 +55,8 @@ void AffineQuantize(int scale_multiplier, inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8); // Q0.23 * Q16.0 == Q16.23 - ae_q56s sum_56 = AE_ZEROQ56(); - { - AE_MULAS56P24S_HH(sum_56, scale_multiplier_24x2, inputs_24x2); + ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2); // Q16.23 -> Q16.0 // Shift right only 7 bits (23 - 16). This truncated shift aligns the @@ -78,10 +76,8 @@ void AffineQuantize(int scale_multiplier, output_data[i * 2] = static_cast(AE_TRUNCA32Q48(sum_56)); } - - sum_56 = AE_ZEROQ56(); { - AE_MULAS56P24S_LL(sum_56, scale_multiplier_24x2, inputs_24x2); + ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2); // Q16.23 -> Q16.0 // Shift right only 7 bits (23 - 16). 
This truncated shift aligns the diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index 5ddc36eb75c..c77e9d1173c 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -71,21 +71,29 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, return kTfLiteOk; } -} // namespace - -void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, - const SoftmaxParams& op_params) { - if (output->type == kTfLiteInt16) { - tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); - } else { - tflite::reference_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); +TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input, + TfLiteTensor* output, + const SoftmaxParams& op_params) { + switch (output->type) { + case kTfLiteInt16: + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; + case kTfLiteInt8: + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(output->type), output->type); + return kTfLiteError; } } +} // namespace + TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { auto* params = static_cast(node->builtin_data); @@ -113,10 +121,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { - case kTfLiteInt8: { - SoftmaxQuantized(input, output, *op_params); - return kTfLiteOk; - } + case kTfLiteInt8: + return 
SoftmaxQuantized(context, input, output, *op_params); default: TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TfLiteTypeGetName(input->type), input->type); diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc index 53ee9d70b64..1847a4e88e8 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc @@ -48,7 +48,7 @@ struct OpData { int effective_scale_2_b; }; -static int kStaticOpDataCounter = 0; +static int op_data_counter = 0; static OpData kStaticOpData[kMaxOpDataSize]; /** @@ -245,8 +245,6 @@ void EvalIntegerSVDF( } } -} // namespace - // Input tensors. constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; @@ -257,6 +255,9 @@ constexpr int kInputActivationStateTensor = 4; // Output tensor. constexpr int kOutputTensor = 0; +} // namespace + +void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); @@ -353,7 +354,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. - OpData* op_data = &kStaticOpData[kStaticOpDataCounter++]; + OpData* op_data = &kStaticOpData[op_data_counter++]; node->user_data = op_data; // Calculate effective scales. 
@@ -410,7 +411,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_SVDF() { static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, + /*free=*/svdf::Free, /*prepare=*/svdf::Prepare, /*invoke=*/svdf::Eval, /*profiling_string=*/nullptr, diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index c3044a0351f..8585c8fa5b8 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -407,7 +407,7 @@ TfLiteStatus MicroAllocator::Init() { TfLiteStatus status = internal::InitializeRuntimeTensor( memory_allocator_, *tensors_->Get(i), model_->buffers(), error_reporter_, &context_->tensors[i]); - if (status == kTfLiteError) { + if (status != kTfLiteOk) { TF_LITE_REPORT_ERROR(error_reporter_, "Failed to initialize tensor %d", i); return kTfLiteError; diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index f4e7fa8dfba..010e1f9e336 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -58,6 +58,9 @@ CreateFlatbufferBuffers(); // Performs a simple string comparison without requiring standard C library. int TestStrcmp(const char* a, const char* b); +// Wrapper to forward kernel errors to the interpreter's error reporter. 
+void ReportOpError(struct TfLiteContext* context, const char* format, ...); + void PopulateContext(TfLiteTensor* tensors, int tensors_size, TfLiteContext* context); diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 1d8c0745e4a..245e919bb05 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -22,7 +22,6 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", ], diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc index 5fd0161d621..9f7803fcf62 100644 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ b/tensorflow/lite/micro/testing/test_utils.cc @@ -15,107 +15,24 @@ limitations under the License. #include "tensorflow/lite/micro/testing/test_utils.h" -#include "tensorflow/lite/kernels/internal/compatibility.h" - namespace tflite { namespace testing { -TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) { - uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); - *ptr = addr; - return kTfLiteOk; -} - -TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx, - size_t bytes, - int* buffer_idx) { - if (scratch_buffers_count_ >= max_scratch_buffers_count_) { - return kTfLiteError; - } - uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); - scratch_buffers_[scratch_buffers_count_] = ptr; - *buffer_idx = scratch_buffers_count_; - scratch_buffers_count_++; - return kTfLiteOk; -} - -void FakeAllocator::Reset() { - // Get A fresh memory allocator. - memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_); - TFLITE_DCHECK_NE(memory_allocator_, nullptr); - - // Allocate enough space holding pointers to the scrtach buffers. 
- scratch_buffers_ = - reinterpret_cast(memory_allocator_->AllocateFromTail( - sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*))); - TFLITE_DCHECK_NE(scratch_buffers_, nullptr); - - scratch_buffers_count_ = 0; -} - -void* FakeAllocator::GetScratchBuffer(int buffer_idx) { - if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) { - return nullptr; - } - return scratch_buffers_[buffer_idx]; -} - -TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx, - size_t bytes, - void** ptr) { - return reinterpret_cast(ctx->impl_) - ->allocator_->AllocatePersistentBuffer(bytes, ptr); -} - -TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx) { - FakeContextHelper* helper = reinterpret_cast(ctx->impl_); - // FakeAllocator doesn't do memory reusing so it doesn't need node_idx to - // calculate the lifetime of the scratch buffer. - int node_idx = -1; - return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes, - buffer_idx); -} - -void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { - return reinterpret_cast(ctx->impl_) - ->allocator_->GetScratchBuffer(buffer_idx); -} - -void FakeContextHelper::ReportOpError(struct TfLiteContext* context, - const char* format, ...) { - FakeContextHelper* helper = static_cast(context->impl_); - va_list args; - va_start(args, format); - TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args); - va_end(args); -} - -namespace { -constexpr size_t kArenaSize = 10000; -constexpr int kMaxScratchBufferCount = 32; -uint8_t arena[kArenaSize]; -} // namespace - // TODO(b/141330728): Move this method elsewhere as part clean up. void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context) { - // This should be a large enough arena for each test cases. 
- static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount); - static FakeContextHelper helper(error_reporter, &allocator); - // Reset the allocator so that it's ready for another test. - allocator.Reset(); - - *context = {}; - context->recommended_num_threads = 1; context->tensors_size = tensors_size; context->tensors = tensors; - context->impl_ = static_cast(&helper); - context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer; - context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena; - context->GetScratchBuffer = helper.GetScratchBuffer; - context->ReportError = helper.ReportOpError; + context->impl_ = static_cast(error_reporter); + context->GetExecutionPlan = nullptr; + context->ResizeTensor = nullptr; + context->ReportError = ReportOpError; + context->AddTensors = nullptr; + context->GetNodeAndRegistration = nullptr; + context->ReplaceNodeSubsetsWithDelegateKernels = nullptr; + context->recommended_num_threads = 1; + context->GetExternalContext = nullptr; + context->SetExternalContext = nullptr; for (int i = 0; i < tensors_size; ++i) { if (context->tensors[i].is_variable) { diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h index f7f5dff6bb1..7aa1e9d488f 100644 --- a/tensorflow/lite/micro/testing/test_utils.h +++ b/tensorflow/lite/micro/testing/test_utils.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/micro/micro_utils.h" -#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -96,67 +95,7 @@ inline int32_t F2Q32(const float value, const float scale) { return static_cast(quantized); } -// A fake version of MemoryAllocator that allocates everything from the tail -// without static memory planning or reusing. 
-// TODO(b/150260678): Consider splitting this into its own file and inherit from -// the same public interface as MicroAllocator. -class FakeAllocator { - public: - FakeAllocator(uint8_t* arena, size_t arena_size, - size_t max_scratch_buffers_count) - : arena_(arena), - arena_size_(arena_size), - max_scratch_buffers_count_(max_scratch_buffers_count) { - Reset(); - } - - TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr); - TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes, - int* buffer_idx); - void* GetScratchBuffer(int buffer_idx); - - // Reset the allocator to the intial state. - void Reset(); - - private: - uint8_t* arena_; - size_t arena_size_; - size_t max_scratch_buffers_count_; - - SimpleMemoryAllocator* memory_allocator_; - // An array of buffer pointers. - uint8_t** scratch_buffers_; - size_t scratch_buffers_count_ = 0; - static constexpr size_t kBufferAlignment = 16; -}; - -// A fake implementation of ContextHelper. Instead of forwarding requests to -// MicroAllocator, it calls into FakeAllocator. -// PopulateContext will point context->impl_ to an instance of this class. -// TODO(b/150260678): Consider moving this into the same file as FakeAllocator. -class FakeContextHelper { - public: - explicit FakeContextHelper(ErrorReporter* error_reporter, - FakeAllocator* allocator) - : allocator_(allocator), error_reporter_(error_reporter) {} - - static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes, - void** ptr); - - static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx); - - static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); - - static void ReportOpError(struct TfLiteContext* context, const char* format, - ...); - - private: - FakeAllocator* allocator_; - ErrorReporter* error_reporter_; -}; - +// TODO(b/141330728): Move this method elsewhere as part clean up. 
void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context); diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index ac066408d9a..0b7e63cae5b 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -38,7 +38,6 @@ ifeq ($(TARGET), bluepill) -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ - -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ -fno-use-cxa-atexit \ diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 0b140ec3826..bf9bee02971 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -400,7 +400,10 @@ def build_toco_convert_protos(input_tensors, model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): input_array = model.input_arrays.add() - input_array.name = util.get_tensor_name(input_tensor) + if saved_model_dir: + input_array.name = input_tensor.name + else: + input_array.name = util.get_tensor_name(input_tensor) input_array.data_type = util.convert_dtype_to_tflite_type( input_tensor.dtype) @@ -423,7 +426,10 @@ def build_toco_convert_protos(input_tensors, input_array.shape.dims.extend(dims) for output_tensor in output_tensors: - model.output_arrays.append(util.get_tensor_name(output_tensor)) + if saved_model_dir: + model.output_arrays.append(output_tensor.name) + else: + model.output_arrays.append(util.get_tensor_name(output_tensor)) model.allow_nonexistent_arrays = allow_nonexistent_arrays diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc index 84264ad803a..27159365f69 100644 --- 
a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 83b0e2b734c..96f3428efe3 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -72,6 +72,7 @@ from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundEr from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.keras.saving import saving_utils as _saving_utils from tensorflow.python.lib.io import file_io as _file_io +from tensorflow.python.saved_model import loader_impl as _loader_impl from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants from tensorflow.python.saved_model.load import load as _load @@ -382,6 +383,9 @@ class TFLiteConverterBase(object): def _parse_saved_model_args(self): """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" + if not self.experimental_new_converter: + self._saved_model_dir = None + return if self._saved_model_dir: try: saved_model_proto, _ = ( @@ -594,28 +598,47 @@ class TFLiteConverterV2(TFLiteConverterBase): self._parse_saved_model_args() # graph_def is used here to preserve the node bug information - frozen_func, graph_def = ( - 
_convert_to_constants.convert_variables_to_constants_v2_as_graph( - self._funcs[0], lower_control_flow=False)) - self._graph_def = graph_def - input_tensors = [ - tensor for tensor in frozen_func.inputs - if tensor.dtype != _dtypes.resource - ] - output_tensors = frozen_func.outputs + if self._saved_model_dir: + graph = _ops.Graph() + saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir) + saved_model.load_graph(graph, tags=self._saved_model_tags) + meta_graph = saved_model.get_meta_graph_def_from_tags( + self._saved_model_tags) + signature_def = meta_graph.signature_def[ + _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + input_tensors = [ + graph.get_tensor_by_name(signature_def.inputs[key].name) + for key in signature_def.inputs + ] + output_tensors = [ + graph.get_tensor_by_name(signature_def.outputs[key].name) + for key in signature_def.outputs + ] + self._graph_def = graph_def = meta_graph.graph_def + else: + frozen_func, graph_def = ( + _convert_to_constants.convert_variables_to_constants_v2_as_graph( + self._funcs[0], lower_control_flow=False)) + self._graph_def = graph_def - # Run a Grappler pass. - grappler_config = self._grappler_config() - # Skip running grappler when there are no optimizers to run. If not, - # grappler will run with the default optimizer set and it will lead to - # causing an unexpected behavior. - if grappler_config.graph_options.rewrite_options.optimizers: - graph_def = _run_graph_optimizations( - graph_def, - input_tensors, - output_tensors, - config=grappler_config, - graph=frozen_func.graph) + input_tensors = [ + tensor for tensor in frozen_func.inputs + if tensor.dtype != _dtypes.resource + ] + output_tensors = frozen_func.outputs + + # Run a Grappler pass. + grappler_config = self._grappler_config() + # Skip running grappler when there are no optimizers to run. If not, + # grappler will run with the default optimizer set and it will lead to + # causing an unexpected behavior. 
+ if grappler_config.graph_options.rewrite_options.optimizers: + graph_def = _run_graph_optimizations( + graph_def, + input_tensors, + output_tensors, + config=grappler_config, + graph=frozen_func.graph) quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, graph_def) @@ -1228,28 +1251,29 @@ class TFLiteConverter(TFLiteConverterBase): "are not enabled.") optimized_graph = self._graph_def - # if it is not uint8 or int8 with post-training quantization, it is not - # quantization aware training, then graph optimization is applied. - # Graph optimization is disabled for quantization aware training. - if (self.inference_type != constants.QUANTIZED_UINT8 or - (self.inference_type == constants.INT8 and - (post_training_optimize or weight_only_quantize))): - try: - # TODO(b/150163103): Merge `disabling lower using switch merge' calls. - # Grappler will also try to lower while loop into switch merge - # representation which is undesired for Ophints, so we simply remove - # those attributes to prevent Grappler from doing so. - graph_def = _convert_to_constants.disable_lower_using_switch_merge( - optimized_graph) - # Run function inlining optimization to ensure any models generated - # through the from_frozen_graph path have been inlined. - optimized_graph = _run_graph_optimizations( - graph_def, - self._input_tensors, - self._output_tensors, - config=self._grappler_config(["function"])) - except Exception: - optimized_graph = self._graph_def + if not self._saved_model_dir: + # if it is not uint8 or int8 with post-training quantization, it is not + # quantization aware training, then graph optimization is applied. + # Graph optimization is disabled for quantization aware training. + if (self.inference_type != constants.QUANTIZED_UINT8 or + (self.inference_type == constants.INT8 and + (post_training_optimize or weight_only_quantize))): + try: + # TODO(b/150163103): Merge `disabling lower using switch merge' calls. 
+ # Grappler will also try to lower while loop into switch merge + # representation which is undesired for Ophints, so we simply remove + # those attributes to prevent Grappler from doing so. + graph_def = _convert_to_constants.disable_lower_using_switch_merge( + optimized_graph) + # Run function inlining optimization to ensure any models generated + # through the from_frozen_graph path have been inlined. + optimized_graph = _run_graph_optimizations( + graph_def, + self._input_tensors, + self._output_tensors, + config=self._grappler_config(["function"])) + except Exception: + optimized_graph = self._graph_def self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 6ec70cf0f20..445a8b4cfed 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -529,7 +529,9 @@ class FromSessionTest(TestModels, parameterized.TestCase): shape=[1, 16, 16, 3], dtype=dtypes.float32) var = variable_scope.get_variable( 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) - out_tensor = in_tensor + var + # Get the second output to ensure freezing properly processes tensor names + # like 'X:1'. 
+ out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1] sess = session.Session() sess.run(_global_variables_initializer()) @@ -552,9 +554,9 @@ class FromSessionTest(TestModels, parameterized.TestCase): output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) - self.assertEqual(np.float32, output_details[0]['dtype']) - self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual('top_k:1', output_details[0]['name']) + self.assertEqual(np.int32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 1] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) def testGraphviz(self): @@ -1632,19 +1634,19 @@ class FromSavedModelTest(TestModels): input_details = interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) @@ -1694,19 +1696,19 @@ class FromSavedModelTest(TestModels): input_details = 
interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) @@ -1726,19 +1728,19 @@ class FromSavedModelTest(TestModels): input_details = interpreter.get_input_details() self.assertEqual(2, len(input_details)) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith(input_details[1]['name'], 'inputB') self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) - self.assertEqual('add', output_details[0]['name']) + 
self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 9512bdca70d..d04117c1a32 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -373,19 +373,22 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest): input_details = interpreter.get_input_details() self.assertLen(input_details, 2) - self.assertEqual('inputA', input_details[0]['name']) + self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) - self.assertEqual('inputB', input_details[1]['name']) + self.assertStartsWith( + input_details[1]['name'], + 'inputB', + ) self.assertEqual(np.float32, input_details[1]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) self.assertEqual((0., 0.), input_details[1]['quantization']) output_details = interpreter.get_output_details() self.assertLen(output_details, 1) - self.assertEqual('add', output_details[0]['name']) + self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) diff --git a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc index 3d75eccd505..06a617463aa 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/lite/python/optimize/calibration_wrapper.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc index 8a090d5b50a..6c63c83f45e 100644 --- a/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/lite/python/optimize/sparsification_wrapper.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/lite/python/testdata/test_registerer_wrapper.cc b/tensorflow/lite/python/testdata/test_registerer_wrapper.cc index c50dee4346c..834f2112d14 100644 --- a/tensorflow/lite/python/testdata/test_registerer_wrapper.cc +++ b/tensorflow/lite/python/testdata/test_registerer_wrapper.cc @@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/lite/python/testdata/test_registerer.h" PYBIND11_MODULE(_pywrap_test_registerer, m) { diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index faf9ba611ed..32a2d596629 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -263,9 +263,9 @@ def freeze_graph(sess, input_tensors, output_tensors): hinted_outputs_nodes) if not is_frozen_graph(sess): - output_arrays = [get_tensor_name(tensor) for tensor in output_tensors] + output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors] return tf_graph_util.convert_variables_to_constants(sess, graph_def, - output_arrays) + output_node_names) else: return sess.graph_def diff --git a/tensorflow/lite/testing/op_tests/fill.py b/tensorflow/lite/testing/op_tests/fill.py index 541651a5445..d5ef39854d0 100644 --- a/tensorflow/lite/testing/op_tests/fill.py +++ b/tensorflow/lite/testing/op_tests/fill.py @@ -31,7 +31,7 @@ def make_fill_tests(options): test_parameters = [{ "dims_dtype": [tf.int32, tf.int64], "dims_shape": [[], [1], [3], [3, 3]], - "value_dtype": [tf.int32, tf.int64, tf.float32], + "value_dtype": [tf.int32, tf.int64, tf.float32, tf.bool, tf.string], }] def build_graph(parameters): @@ -57,4 +57,4 @@ def make_fill_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=12) + expected_tf_failures=20) diff --git a/tensorflow/lite/testing/op_tests/hardswish.py b/tensorflow/lite/testing/op_tests/hardswish.py index 2816fe5bd64..97dad804f3b 100644 --- a/tensorflow/lite/testing/op_tests/hardswish.py +++ b/tensorflow/lite/testing/op_tests/hardswish.py @@ -48,10 +48,17 @@ def make_hardswish_tests(options): """Make a set of tests to do hardswish.""" # Chose a set of parameters - test_parameters = 
[{ - "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], - [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], - }] + if options.run_with_flex: + # Only Flex is able to execute on the data bigger than four dimension. + test_parameters = [{ + "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], + [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + else: + test_parameters = [{ + "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], + [3, 15, 14, 3]], + }] def build_graph(parameters): inp = tf.compat.v1.placeholder( diff --git a/tensorflow/lite/testing/string_util_wrapper.cc b/tensorflow/lite/testing/string_util_wrapper.cc index f5b490ab617..8d7d4588c3b 100644 --- a/tensorflow/lite/testing/string_util_wrapper.cc +++ b/tensorflow/lite/testing/string_util_wrapper.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/lite/testing/string_util.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index c5de72ef9f7..8fc00775633 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -133,6 +133,11 @@ def create_scalar_data(dtype, min_value=-100, max_value=100): value = (max_value - min_value) * np.random.random() + min_value elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1) + elif dtype == tf.bool: + value = np.random.choice([True, False]) + elif dtype == np.string_: + l = np.random.randint(1, 6) + value = "".join(np.random.choice(list(string.ascii_uppercase), size=l)) return np.array(value, dtype=dtype) diff --git 
a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index ac9d575ab64..01a750b5d69 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -41,6 +41,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { new std::map, string>({ {{OperatorType::kAveragePool, 1}, "1.5.0"}, {{OperatorType::kAveragePool, 2}, "1.14.0"}, + {{OperatorType::kAveragePool, 3}, kPendingReleaseOpVersion}, {{OperatorType::kConv, 1}, "1.5.0"}, {{OperatorType::kConv, 2}, "1.14.0"}, {{OperatorType::kConv, 3}, "1.14.0"}, @@ -83,6 +84,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kLocalResponseNormalization, 1}, "1.5.0"}, {{OperatorType::kMaxPool, 1}, "1.5.0"}, {{OperatorType::kMaxPool, 2}, "1.14.0"}, + {{OperatorType::kMaxPool, 3}, kPendingReleaseOpVersion}, {{OperatorType::kMaximum, 1}, "1.14.0"}, {{OperatorType::kMaximum, 2}, "1.14.0"}, {{OperatorType::kMaximum, 3}, kPendingReleaseOpVersion}, @@ -224,7 +226,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kZerosLike, 1}, "1.12.0"}, {{OperatorType::kAbs, 1}, "1.13.0"}, {{OperatorType::kHardSwish, 1}, "1.15.0"}, - {{OperatorType::kFill, 1}, "1.13.0"}, + {{OperatorType::kFill, 2}, kPendingReleaseOpVersion}, {{OperatorType::kReverseV2, 1}, "1.14.0"}, {{OperatorType::kReverseV2, 2}, "2.2.0"}, {{OperatorType::kRank, 1}, "1.14.0"}, diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 41ccb3df36e..3b8f501e5d8 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -205,12 +205,37 @@ cc_test( ], ) +cc_library( + name = "logging", + hdrs = ["logging.h"], + copts = common_copts, +) + +cc_library( + name = "tool_params", + srcs = ["tool_params.cc"], + hdrs = ["tool_params.h"], + copts = tflite_copts(), + deps = [":logging"], +) + +cc_test( + name = "tool_params_test", + srcs = ["tool_params_test.cc"], + copts = tflite_copts(), + visibility = 
["//visibility:private"], + deps = [ + ":tool_params", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "command_line_flags", srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], copts = tflite_copts(), - deps = ["//tensorflow/lite:minimal_logging"], + deps = ["//tensorflow/lite/tools:logging"], ) cc_test( @@ -220,8 +245,7 @@ cc_test( visibility = ["//visibility:private"], deps = [ ":command_line_flags", - "//tensorflow/lite/testing:util", - "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index a979a8a55ef..16bd2644fd8 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -17,6 +17,9 @@ cc_library( name = "logging", hdrs = ["logging.h"], copts = common_copts, + deps = [ + "//tensorflow/lite/tools:logging", + ], ) cc_binary( @@ -107,6 +110,7 @@ cc_test( ":benchmark_performance_options", ":benchmark_tflite_model_lib", ":delegate_provider_hdr", + ":logging", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/testing:util", @@ -128,6 +132,7 @@ cc_library( "//tensorflow/lite/profiling:profile_summarizer", "//tensorflow/lite/profiling:profile_summary_formatter", "//tensorflow/lite/profiling:profiler", + "//tensorflow/lite/tools:logging", ], ) @@ -142,31 +147,23 @@ cc_library( "//conditions:default": [], }), deps = [ - ":profiling_listener", ":benchmark_model_lib", ":benchmark_utils", ":delegate_provider_hdr", - ":gpu_delegate_provider", - ":hexagon_delegate_provider", - ":external_delegate_provider", ":logging", - ":nnapi_delegate_provider", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", + ":profiling_listener", + ":tflite_execution_providers", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", - "@ruy//ruy/profiler", "//tensorflow/lite/kernels:builtin_ops", 
"//tensorflow/lite/profiling:platform_profiler", - "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/profiling:profile_summary_formatter", + "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/tools/evaluation:utils", - ] + select({ - "//tensorflow:fuchsia": [], - "//conditions:default": [ - ":xnnpack_delegate_provider", - ], - }), + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@ruy//ruy/profiler", + ], ) cc_library( @@ -197,12 +194,9 @@ cc_library( cc_library( name = "benchmark_params", - srcs = [ - "benchmark_params.cc", - ], hdrs = ["benchmark_params.h"], copts = common_copts, - deps = [":logging"], + deps = ["//tensorflow/lite/tools:tool_params"], ) cc_library( @@ -235,9 +229,41 @@ cc_library( ":benchmark_params", "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools/benchmark:logging", ], ) +# A convenient library for all inference execution providers. +cc_library( + name = "tflite_execution_providers", + copts = tflite_copts(), + deps = [ + ":default_execution_provider", + ":external_delegate_provider", + ":gpu_delegate_provider", + ":hexagon_delegate_provider", + ":nnapi_delegate_provider", + ] + select({ + "//tensorflow:fuchsia": [], + "//conditions:default": [ + ":xnnpack_delegate_provider", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "default_execution_provider", + srcs = ["default_execution_provider.cc"], + copts = tflite_copts(), + linkstatic = True, + visibility = ["//visibility:public"], + deps = [ + ":delegate_provider_hdr", + ], + alwayslink = 1, +) + cc_library( name = "gpu_delegate_provider", srcs = ["gpu_delegate_provider.cc"], @@ -248,10 +274,7 @@ cc_library( "//conditions:default": [], }), deps = [ - ":benchmark_model_lib", - ":benchmark_params", ":delegate_provider_hdr", - ":logging", "//tensorflow/lite/tools/evaluation:utils", ] + select({ "//tensorflow:android": [ @@ -270,10 +293,7 @@ cc_library( srcs = 
["nnapi_delegate_provider.cc"], copts = common_copts, deps = [ - ":benchmark_model_lib", - ":benchmark_params", ":delegate_provider_hdr", - ":logging", "//tensorflow/lite/tools/evaluation:utils", ], alwayslink = 1, @@ -284,10 +304,7 @@ cc_library( srcs = ["hexagon_delegate_provider.cc"], copts = common_copts, deps = [ - ":benchmark_model_lib", - ":benchmark_params", ":delegate_provider_hdr", - ":logging", "//tensorflow/lite/tools/evaluation:utils", ], alwayslink = 1, @@ -300,9 +317,7 @@ cc_library( linkstatic = True, visibility = ["//visibility:public"], deps = [ - ":benchmark_model_lib", ":delegate_provider_hdr", - ":logging", "//tensorflow/lite/tools/evaluation:utils", ], alwayslink = 1, @@ -315,9 +330,7 @@ cc_library( linkstatic = True, visibility = ["//visibility:public"], deps = [ - ":benchmark_model_lib", ":delegate_provider_hdr", - ":logging", ], alwayslink = 1, ) diff --git a/tensorflow/lite/tools/benchmark/benchmark_params.cc b/tensorflow/lite/tools/benchmark/benchmark_params.cc deleted file mode 100644 index 1dd6a8d519a..00000000000 --- a/tensorflow/lite/tools/benchmark/benchmark_params.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/tools/benchmark/benchmark_params.h" - -#include -#include -#include - -#include "tensorflow/lite/tools/benchmark/logging.h" - -namespace tflite { -namespace benchmark { - -void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a, - BenchmarkParam::ParamType b) { - TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter."; -} - -template <> -BenchmarkParam::ParamType BenchmarkParam::GetValueType() { - return BenchmarkParam::ParamType::TYPE_INT32; -} - -template <> -BenchmarkParam::ParamType BenchmarkParam::GetValueType() { - return BenchmarkParam::ParamType::TYPE_BOOL; -} - -template <> -BenchmarkParam::ParamType BenchmarkParam::GetValueType() { - return BenchmarkParam::ParamType::TYPE_FLOAT; -} - -template <> -BenchmarkParam::ParamType BenchmarkParam::GetValueType() { - return BenchmarkParam::ParamType::TYPE_STRING; -} - -void BenchmarkParams::AssertParamExists(const std::string& name) const { - TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found."; -} - -void BenchmarkParams::Set(const BenchmarkParams& other) { - for (const auto& param : params_) { - const BenchmarkParam* other_param = other.GetParam(param.first); - if (other_param == nullptr) continue; - param.second->Set(*other_param); - } -} - -void BenchmarkParams::Merge(const BenchmarkParams& other, bool overwrite) { - for (const auto& one : other.params_) { - auto it = params_.find(one.first); - if (it == params_.end()) { - AddParam(one.first, one.second->Clone()); - } else if (overwrite) { - it->second->Set(*one.second); - } - } -} - -} // namespace benchmark -} // namespace tflite diff --git a/tensorflow/lite/tools/benchmark/benchmark_params.h b/tensorflow/lite/tools/benchmark/benchmark_params.h index 1b3dabf3f7b..9037c35869f 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_params.h +++ b/tensorflow/lite/tools/benchmark/benchmark_params.h @@ -15,123 
+15,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ #define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ -#include -#include -#include -#include -#include - -#include "tensorflow/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/tools/tool_params.h" namespace tflite { namespace benchmark { - -template -class TypedBenchmarkParam; - -class BenchmarkParam { - protected: - enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; - template - static ParamType GetValueType(); - - public: - template - static std::unique_ptr Create(const T& default_value) { - return std::unique_ptr( - new TypedBenchmarkParam(default_value)); - } - - template - TypedBenchmarkParam* AsTyped() { - AssertHasSameType(GetValueType(), type_); - return static_cast*>(this); - } - - template - const TypedBenchmarkParam* AsConstTyped() const { - AssertHasSameType(GetValueType(), type_); - return static_cast*>(this); - } - - virtual ~BenchmarkParam() {} - explicit BenchmarkParam(ParamType type) : type_(type) {} - - virtual void Set(const BenchmarkParam&) {} - - virtual std::unique_ptr Clone() const = 0; - - private: - static void AssertHasSameType(ParamType a, ParamType b); - - const ParamType type_; -}; - -template -class TypedBenchmarkParam : public BenchmarkParam { - public: - explicit TypedBenchmarkParam(const T& value) - : BenchmarkParam(GetValueType()), value_(value) {} - - void Set(const T& value) { value_ = value; } - - T Get() const { return value_; } - - void Set(const BenchmarkParam& other) override { - Set(other.AsConstTyped()->Get()); - } - - std::unique_ptr Clone() const override { - return std::unique_ptr(new TypedBenchmarkParam(value_)); - } - - private: - T value_; -}; - -class BenchmarkParams { - public: - void AddParam(const std::string& name, - std::unique_ptr value) { - params_[name] = std::move(value); - } - - bool HasParam(const std::string& name) const { - return params_.find(name) != 
params_.end(); - } - - bool Empty() const { return params_.empty(); } - - const BenchmarkParam* GetParam(const std::string& name) const { - const auto& entry = params_.find(name); - if (entry == params_.end()) return nullptr; - return entry->second.get(); - } - - template - void Set(const std::string& name, const T& value) { - AssertParamExists(name); - params_.at(name)->AsTyped()->Set(value); - } - - template - T Get(const std::string& name) const { - AssertParamExists(name); - return params_.at(name)->AsTyped()->Get(); - } - - // Set the value of all same parameters from 'other'. - void Set(const BenchmarkParams& other); - - // Merge the value of all parameters from 'other'. 'overwrite' indicates - // whether the value of the same paratmeter is overwrite or not. - void Merge(const BenchmarkParams& other, bool overwrite = false); - - private: - void AssertParamExists(const std::string& name) const; - std::unordered_map> params_; -}; - +using BenchmarkParam = tflite::tools::ToolParam; +using BenchmarkParams = tflite::tools::ToolParams; } // namespace benchmark } // namespace tflite #endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index da4082926a2..3560f866eff 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" +#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/command_line_flags.h" namespace { @@ -80,14 +81,13 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs, params.AddParam("enable_op_profiling", BenchmarkParam::Create(false)); params.AddParam("max_profiling_buffer_entries", BenchmarkParam::Create(1024)); - params.AddParam("max_delegated_partitions", BenchmarkParam::Create(0)); params.AddParam("profiling_output_csv_file", BenchmarkParam::Create("")); params.AddParam("enable_platform_tracing", BenchmarkParam::Create(false)); for (const auto& delegate_provider : GetRegisteredDelegateProviders()) { - delegate_provider->AddParams(¶ms); + params.Merge(delegate_provider->DefaultParams()); } return params; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index a451eab5448..261825923fd 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -39,7 +39,6 @@ limitations under the License. 
#include "tensorflow/lite/tools/benchmark/delegate_provider.h" #include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/benchmark/profiling_listener.h" -#include "tensorflow/lite/tools/evaluation/utils.h" void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); @@ -270,13 +269,11 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create(1024)); default_params.AddParam("profiling_output_csv_file", BenchmarkParam::Create("")); - default_params.AddParam("max_delegated_partitions", - BenchmarkParam::Create(0)); default_params.AddParam("enable_platform_tracing", BenchmarkParam::Create(false)); for (const auto& delegate_util : GetRegisteredDelegateProviders()) { - delegate_util->AddParams(&default_params); + default_params.Merge(delegate_util->DefaultParams()); } return default_params; @@ -296,7 +293,7 @@ void BenchmarkTfLiteModel::CleanUp() { BenchmarkTfLiteModel::~BenchmarkTfLiteModel() { CleanUp(); } std::vector BenchmarkTfLiteModel::GetFlags() { - std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); + std::vector flags = BenchmarkModel::GetFlags(); std::vector specific_flags = { CreateFlag("graph", ¶ms_, "graph file name"), CreateFlag("input_layer", ¶ms_, "input layer names"), @@ -329,8 +326,6 @@ std::vector BenchmarkTfLiteModel::GetFlags() { "profiling_output_csv_file", ¶ms_, "File path to export profile data as CSV, if not set " "prints to stdout."), - CreateFlag("max_delegated_partitions", ¶ms_, - "Max partitions to be delegated."), CreateFlag("enable_platform_tracing", ¶ms_, "enable platform-wide tracing, only meaningful when " "--enable_op_profiling is set to true.")}; @@ -374,8 +369,6 @@ void BenchmarkTfLiteModel::LogParams() { TFLITE_LOG(INFO) << "CSV File to export profiling data to: [" << params_.Get("profiling_output_csv_file") << "]"; - TFLITE_LOG(INFO) << "Max number of delegated partitions : [" - << params_.Get("max_delegated_partitions") << "]"; TFLITE_LOG(INFO) << "Enable 
platform-wide tracing: [" << params_.Get("enable_platform_tracing") << "]"; diff --git a/tensorflow/lite/tools/benchmark/default_execution_provider.cc b/tensorflow/lite/tools/benchmark/default_execution_provider.cc new file mode 100644 index 00000000000..f7204cba954 --- /dev/null +++ b/tensorflow/lite/tools/benchmark/default_execution_provider.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/lite/tools/benchmark/delegate_provider.h" + +namespace tflite { +namespace benchmark { + +// This class actually doesn't provide any TFLite delegate instances, it simply +// provides common params and flags that are common to all actual delegate +// providers. 
+class DefaultExecutionProvider : public DelegateProvider { + public: + DefaultExecutionProvider() { + default_params_.AddParam("num_threads", BenchmarkParam::Create(1)); + default_params_.AddParam("max_delegated_partitions", + BenchmarkParam::Create(0)); + } + + std::vector CreateFlags(BenchmarkParams* params) const final; + void LogParams(const BenchmarkParams& params) const final; + TfLiteDelegatePtr CreateTfLiteDelegate( + const BenchmarkParams& params) const final; + std::string GetName() const final { return "Default-NoDelegate"; } +}; +REGISTER_DELEGATE_PROVIDER(DefaultExecutionProvider); + +std::vector DefaultExecutionProvider::CreateFlags( + BenchmarkParams* params) const { + std::vector flags = { + CreateFlag("num_threads", params, + "number of threads used for inference on CPU."), + CreateFlag("max_delegated_partitions", params, + "Max number of partitions to be delegated.")}; + return flags; +} + +void DefaultExecutionProvider::LogParams(const BenchmarkParams& params) const { + TFLITE_LOG(INFO) << "#threads used for CPU inference: [" + << params.Get("num_threads") << "]"; + TFLITE_LOG(INFO) << "Max number of delegated partitions : [" + << params.Get("max_delegated_partitions") << "]"; +} + +TfLiteDelegatePtr DefaultExecutionProvider::CreateTfLiteDelegate( + const BenchmarkParams& params) const { + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/lite/tools/benchmark/delegate_provider.h b/tensorflow/lite/tools/benchmark/delegate_provider.h index f9a742c997e..a1531de4cad 100644 --- a/tensorflow/lite/tools/benchmark/delegate_provider.h +++ b/tensorflow/lite/tools/benchmark/delegate_provider.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/benchmark/benchmark_params.h" +#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/command_line_flags.h" namespace tflite { @@ -40,9 +41,6 @@ class DelegateProvider { // value. virtual std::vector CreateFlags(BenchmarkParams* params) const = 0; - // Add delegate-specific benchmark pararms to 'params' - virtual void AddParams(BenchmarkParams* params) const = 0; - // Log benchmark params. virtual void LogParams(const BenchmarkParams& params) const = 0; @@ -51,6 +49,18 @@ class DelegateProvider { const BenchmarkParams& params) const = 0; virtual std::string GetName() const = 0; + + const BenchmarkParams& DefaultParams() const { return default_params_; } + + protected: + template + Flag CreateFlag(const char* name, BenchmarkParams* params, + const std::string& usage) const { + return Flag( + name, [params, name](const T& val) { params->Set(name, val); }, + default_params_.Get(name), usage, Flag::OPTIONAL); + } + BenchmarkParams default_params_; }; using DelegateProviderPtr = std::unique_ptr; diff --git a/tensorflow/lite/tools/benchmark/external_delegate_provider.cc b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc index 9174b4a1f95..a5d8a941697 100644 --- a/tensorflow/lite/tools/benchmark/external_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/external_delegate_provider.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" -#include "tensorflow/lite/tools/benchmark/logging.h" #if defined(_WIN32) #include @@ -97,9 +95,14 @@ struct ExternalLib { // the generated delegates. 
class ExternalDelegateProvider : public DelegateProvider { public: - std::vector CreateFlags(BenchmarkParams* params) const final; + ExternalDelegateProvider() { + default_params_.AddParam("external_delegate_path", + BenchmarkParam::Create("")); + default_params_.AddParam("external_delegate_options", + BenchmarkParam::Create("")); + } - void AddParams(BenchmarkParams* params) const final; + std::vector CreateFlags(BenchmarkParams* params) const final; void LogParams(const BenchmarkParams& params) const final; @@ -121,13 +124,6 @@ std::vector ExternalDelegateProvider::CreateFlags( return flags; } -void ExternalDelegateProvider::AddParams(BenchmarkParams* params) const { - params->AddParam("external_delegate_path", - BenchmarkParam::Create("")); - params->AddParam("external_delegate_options", - BenchmarkParam::Create("")); -} - void ExternalDelegateProvider::LogParams(const BenchmarkParams& params) const { TFLITE_LOG(INFO) << "External delegate path : [" << params.Get("external_delegate_path") << "]"; diff --git a/tensorflow/lite/tools/benchmark/gpu_delegate_provider.cc b/tensorflow/lite/tools/benchmark/gpu_delegate_provider.cc index 2deb111652f..72195e183e6 100644 --- a/tensorflow/lite/tools/benchmark/gpu_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/gpu_delegate_provider.cc @@ -14,9 +14,7 @@ limitations under the License. 
==============================================================================*/ #include -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" -#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/evaluation/utils.h" #if defined(__ANDROID__) #include "tensorflow/lite/delegates/gpu/delegate.h" @@ -34,9 +32,19 @@ namespace benchmark { class GpuDelegateProvider : public DelegateProvider { public: - std::vector CreateFlags(BenchmarkParams* params) const final; + GpuDelegateProvider() { + default_params_.AddParam("use_gpu", BenchmarkParam::Create(false)); +#if defined(__ANDROID__) || defined(REAL_IPHONE_DEVICE) + default_params_.AddParam("gpu_precision_loss_allowed", + BenchmarkParam::Create(true)); +#endif +#if defined(REAL_IPHONE_DEVICE) + default_params_.AddParam("gpu_wait_type", + BenchmarkParam::Create("")); +#endif + } - void AddParams(BenchmarkParams* params) const final; + std::vector CreateFlags(BenchmarkParams* params) const final; void LogParams(const BenchmarkParams& params) const final; @@ -66,17 +74,6 @@ std::vector GpuDelegateProvider::CreateFlags( return flags; } -void GpuDelegateProvider::AddParams(BenchmarkParams* params) const { - params->AddParam("use_gpu", BenchmarkParam::Create(false)); -#if defined(__ANDROID__) || defined(REAL_IPHONE_DEVICE) - params->AddParam("gpu_precision_loss_allowed", - BenchmarkParam::Create(true)); -#endif -#if defined(REAL_IPHONE_DEVICE) - params->AddParam("gpu_wait_type", BenchmarkParam::Create("")); -#endif -} - void GpuDelegateProvider::LogParams(const BenchmarkParams& params) const { TFLITE_LOG(INFO) << "Use gpu : [" << params.Get("use_gpu") << "]"; #if defined(__ANDROID__) || defined(REAL_IPHONE_DEVICE) diff --git a/tensorflow/lite/tools/benchmark/hexagon_delegate_provider.cc b/tensorflow/lite/tools/benchmark/hexagon_delegate_provider.cc index 4b341a1d6c3..b06d4f4c94a 100644 --- 
a/tensorflow/lite/tools/benchmark/hexagon_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/hexagon_delegate_provider.cc @@ -14,9 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" -#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/evaluation/utils.h" #if (defined(ANDROID) || defined(__ANDROID__)) && \ @@ -29,9 +27,19 @@ namespace benchmark { class HexagonDelegateProvider : public DelegateProvider { public: - std::vector CreateFlags(BenchmarkParams* params) const final; + HexagonDelegateProvider() { +#if defined(TFLITE_ENABLE_HEXAGON) + default_params_.AddParam("use_hexagon", + BenchmarkParam::Create(false)); + default_params_.AddParam( + "hexagon_lib_path", + BenchmarkParam::Create("/data/local/tmp")); + default_params_.AddParam("hexagon_profiling", + BenchmarkParam::Create(false)); +#endif + } - void AddParams(BenchmarkParams* params) const final; + std::vector CreateFlags(BenchmarkParams* params) const final; void LogParams(const BenchmarkParams& params) const final; @@ -58,15 +66,6 @@ std::vector HexagonDelegateProvider::CreateFlags( #endif } -void HexagonDelegateProvider::AddParams(BenchmarkParams* params) const { -#if defined(TFLITE_ENABLE_HEXAGON) - params->AddParam("use_hexagon", BenchmarkParam::Create(false)); - params->AddParam("hexagon_lib_path", - BenchmarkParam::Create("/data/local/tmp")); - params->AddParam("hexagon_profiling", BenchmarkParam::Create(false)); -#endif -} - void HexagonDelegateProvider::LogParams(const BenchmarkParams& params) const { #if defined(TFLITE_ENABLE_HEXAGON) TFLITE_LOG(INFO) << "Use Hexagon : [" << params.Get("use_hexagon") diff --git a/tensorflow/lite/tools/benchmark/logging.h b/tensorflow/lite/tools/benchmark/logging.h index 808090bf21f..ff3287026b2 100644 --- 
a/tensorflow/lite/tools/benchmark/logging.h +++ b/tensorflow/lite/tools/benchmark/logging.h @@ -16,74 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ #define TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ -// LOG and CHECK macros for benchmarks. +// TODO(b/149482807): completely remove this file from the code base. +#include "tensorflow/lite/tools/logging.h" -#include -#include -#include - -#ifdef _WIN32 -#undef ERROR -#endif - -namespace tflite { -namespace logging { -// A wrapper that logs to stderr. -// -// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. -class LoggingWrapper { - public: - enum class LogSeverity : int { - INFO = 0, - WARN = 1, - ERROR = 2, - FATAL = 3, - }; - LoggingWrapper(LogSeverity severity) - : severity_(severity), should_log_(true) {} - LoggingWrapper(LogSeverity severity, bool log) - : severity_(severity), should_log_(log) {} - std::stringstream& Stream() { return stream_; } - ~LoggingWrapper() { - if (should_log_) { - switch (severity_) { - case LogSeverity::INFO: - case LogSeverity::WARN: - std::cout << stream_.str() << std::endl; - break; - case LogSeverity::ERROR: - std::cerr << stream_.str() << std::endl; - break; - case LogSeverity::FATAL: - std::cerr << stream_.str() << std::endl; - std::flush(std::cerr); - std::abort(); - break; - } - } - } - - private: - std::stringstream stream_; - LogSeverity severity_; - bool should_log_; -}; - -} // namespace logging - -} // namespace tflite - -#define TFLITE_LOG(severity) \ - tflite::logging::LoggingWrapper( \ - tflite::logging::LoggingWrapper::LogSeverity::severity) \ - .Stream() - -#define TFLITE_BENCHMARK_CHECK(condition) \ - tflite::logging::LoggingWrapper( \ - tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ - (condition) ? 
false : true) \ - .Stream() - -#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) +#define TFLITE_BENCHMARK_CHECK(condition) TFLITE_TOOLS_CHECK(condition) +#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK(a == b) #endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc index 3f87de863e7..04aa318b789 100644 --- a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc @@ -14,9 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" -#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/evaluation/utils.h" #if defined(__ANDROID__) #include "tensorflow/lite/nnapi/nnapi_util.h" @@ -27,9 +25,17 @@ namespace benchmark { class NnapiDelegateProvider : public DelegateProvider { public: - std::vector CreateFlags(BenchmarkParams* params) const final; + NnapiDelegateProvider() { + default_params_.AddParam("use_nnapi", BenchmarkParam::Create(false)); + default_params_.AddParam("nnapi_execution_preference", + BenchmarkParam::Create("")); + default_params_.AddParam("nnapi_accelerator_name", + BenchmarkParam::Create("")); + default_params_.AddParam("disable_nnapi_cpu", + BenchmarkParam::Create(false)); + } - void AddParams(BenchmarkParams* params) const final; + std::vector CreateFlags(BenchmarkParams* params) const final; void LogParams(const BenchmarkParams& params) const final; @@ -57,15 +63,6 @@ std::vector NnapiDelegateProvider::CreateFlags( return flags; } -void NnapiDelegateProvider::AddParams(BenchmarkParams* params) const { - params->AddParam("use_nnapi", BenchmarkParam::Create(false)); - 
params->AddParam("nnapi_execution_preference", - BenchmarkParam::Create("")); - params->AddParam("nnapi_accelerator_name", - BenchmarkParam::Create("")); - params->AddParam("disable_nnapi_cpu", BenchmarkParam::Create(false)); -} - void NnapiDelegateProvider::LogParams(const BenchmarkParams& params) const { #if defined(__ANDROID__) TFLITE_LOG(INFO) << "Use nnapi : [" << params.Get("use_nnapi") << "]"; diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.cc b/tensorflow/lite/tools/benchmark/profiling_listener.cc index 50df69c4b7c..ddd653e757d 100644 --- a/tensorflow/lite/tools/benchmark/profiling_listener.cc +++ b/tensorflow/lite/tools/benchmark/profiling_listener.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/lite/tools/logging.h" + namespace tflite { namespace benchmark { @@ -29,7 +31,7 @@ ProfilingListener::ProfilingListener( csv_file_path_(csv_file_path), interpreter_(interpreter), profiler_(max_num_entries) { - TFLITE_BENCHMARK_CHECK(interpreter); + TFLITE_TOOLS_CHECK(interpreter); interpreter_->SetProfiler(&profiler_); // We start profiling here in order to catch events that are recorded during diff --git a/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc b/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc index 8fa9e7de69a..72226396949 100644 --- a/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/xnnpack_delegate_provider.cc @@ -14,9 +14,7 @@ limitations under the License. 
==============================================================================*/ #include -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/delegate_provider.h" -#include "tensorflow/lite/tools/benchmark/logging.h" #include "tensorflow/lite/tools/evaluation/utils.h" namespace tflite { @@ -24,9 +22,12 @@ namespace benchmark { class XnnpackDelegateProvider : public DelegateProvider { public: - std::vector CreateFlags(BenchmarkParams* params) const final; + XnnpackDelegateProvider() { + default_params_.AddParam("use_xnnpack", + BenchmarkParam::Create(false)); + } - void AddParams(BenchmarkParams* params) const final; + std::vector CreateFlags(BenchmarkParams* params) const final; void LogParams(const BenchmarkParams& params) const final; @@ -44,10 +45,6 @@ std::vector XnnpackDelegateProvider::CreateFlags( return flags; } -void XnnpackDelegateProvider::AddParams(BenchmarkParams* params) const { - params->AddParam("use_xnnpack", BenchmarkParam::Create(false)); -} - void XnnpackDelegateProvider::LogParams(const BenchmarkParams& params) const { TFLITE_LOG(INFO) << "Use xnnpack : [" << params.Get("use_xnnpack") << "]"; diff --git a/tensorflow/lite/tools/command_line_flags.cc b/tensorflow/lite/tools/command_line_flags.cc index 841424421e0..c565a5f1484 100644 --- a/tensorflow/lite/tools/command_line_flags.cc +++ b/tensorflow/lite/tools/command_line_flags.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include #include +#include #include #include -#include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/tools/logging.h" namespace tflite { namespace { @@ -165,7 +166,12 @@ std::string Flag::GetTypeName() const { /*static*/ bool Flags::Parse(int* argc, const char** argv, const std::vector& flag_list) { bool result = true; - std::vector unknown_flags(*argc, true); + std::vector unknown_argvs(*argc, true); + // Record the list of flags that have been processed. 
key is the flag's name + // and the value is the corresponding argv index if there's one, or -1 when + // the argv list doesn't contain this flag. + std::unordered_map processed_flags; + // Stores indexes of flag_list in a sorted order. std::vector sorted_idx(flag_list.size()); std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0); @@ -174,45 +180,69 @@ std::string Flag::GetTypeName() const { }); int positional_count = 0; - for (int i = 0; i < sorted_idx.size(); ++i) { - const Flag& flag = flag_list[sorted_idx[i]]; + for (int idx = 0; idx < sorted_idx.size(); ++idx) { + const Flag& flag = flag_list[sorted_idx[idx]]; + + const auto it = processed_flags.find(flag.name_); + if (it != processed_flags.end()) { + TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_; + if (it->second != -1) { + bool value_parsing_ok; + flag.Parse(argv[it->second], &value_parsing_ok); + if (!value_parsing_ok) { + TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_ + << "' against argv '" << argv[it->second] << "'"; + result = false; + } + continue; + } else if (flag.flag_type_ == Flag::REQUIRED) { + // Check if required flag not found. + TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; + result = false; + break; + } + } + // Parses positional flags. if (flag.flag_type_ == Flag::POSITIONAL) { if (++positional_count >= *argc) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Too few command line arguments"); + TFLITE_LOG(ERROR) << "Too few command line arguments."; return false; } bool value_parsing_ok; flag.Parse(argv[positional_count], &value_parsing_ok); if (!value_parsing_ok) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse positional flag: %s", - flag.name_.c_str()); + TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_; return false; } - unknown_flags[positional_count] = false; + unknown_argvs[positional_count] = false; + processed_flags[flag.name_] = positional_count; continue; } // Parse other flags. 
bool was_found = false; for (int i = positional_count + 1; i < *argc; ++i) { - if (!unknown_flags[i]) continue; + if (!unknown_argvs[i]) continue; bool value_parsing_ok; was_found = flag.Parse(argv[i], &value_parsing_ok); if (!value_parsing_ok) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse flag: %s", - flag.name_.c_str()); + TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_ + << "' against argv '" << argv[i] << "'"; result = false; } if (was_found) { - unknown_flags[i] = false; + unknown_argvs[i] = false; + processed_flags[flag.name_] = i; break; } } + if (!was_found) { + processed_flags[flag.name_] = -1; + } // Check if required flag not found. if (flag.flag_type_ == Flag::REQUIRED && !was_found) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Required flag not provided: %s", - flag.name_.c_str()); + TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; result = false; break; } @@ -220,7 +250,7 @@ std::string Flag::GetTypeName() const { int dst = 1; // Skip argv[0] for (int i = 1; i < *argc; ++i) { - if (unknown_flags[i]) { + if (unknown_argvs[i]) { argv[dst++] = argv[i]; } } diff --git a/tensorflow/lite/tools/command_line_flags.h b/tensorflow/lite/tools/command_line_flags.h index 2808a12a489..941a1b8b59a 100644 --- a/tensorflow/lite/tools/command_line_flags.h +++ b/tensorflow/lite/tools/command_line_flags.h @@ -125,6 +125,14 @@ class Flags { // with matching flags, and remove the matching arguments from (*argc, argv). // Return true iff all recognized flag values were parsed correctly, and the // first remaining argument is not "--help". + // Note: + // 1. when there are duplicate args in argv for the same flag, the flag value + // and the parse result will be based on the 1st arg. + // 2. when there are duplicate flags in flag_list (i.e. two flags having the + // same name), all of them will be checked against the arg list and the parse + // result will be false if any of the parsing fails. 
+ // See *Duplicate* unit tests in command_line_flags_test.cc for the + // illustration of such behaviors. static bool Parse(int* argc, const char** argv, const std::vector& flag_list); diff --git a/tensorflow/lite/tools/command_line_flags_test.cc b/tensorflow/lite/tools/command_line_flags_test.cc index 1354c6d503b..a9a351d315f 100644 --- a/tensorflow/lite/tools/command_line_flags_test.cc +++ b/tensorflow/lite/tools/command_line_flags_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { @@ -60,12 +59,12 @@ TEST(CommandLineFlagsTest, BasicUsage) { Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL), }); - EXPECT_EQ(true, parsed_ok); + EXPECT_TRUE(parsed_ok); EXPECT_EQ(20, some_int32); EXPECT_EQ(8, some_int1); EXPECT_EQ(5, some_int2); EXPECT_EQ(214748364700, some_int64); - EXPECT_EQ(true, some_switch); + EXPECT_TRUE(some_switch); EXPECT_EQ("somethingelse", some_name); EXPECT_NEAR(42.0f, some_float, 1e-5f); EXPECT_NEAR(12.2f, float_1, 1e-5f); @@ -82,7 +81,7 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_string", &some_string, "some string")}); - EXPECT_EQ(true, parsed_ok); + EXPECT_TRUE(parsed_ok); EXPECT_EQ(some_string, ""); EXPECT_EQ(argc, 1); } @@ -95,7 +94,7 @@ TEST(CommandLineFlagsTest, BadIntValue) { Flags::Parse(&argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_int", &some_int, "some int")}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_EQ(10, some_int); EXPECT_EQ(argc, 1); } @@ -108,8 +107,8 @@ TEST(CommandLineFlagsTest, BadBoolValue) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); - EXPECT_EQ(false, parsed_ok); - EXPECT_EQ(false, some_switch); + EXPECT_FALSE(parsed_ok); + EXPECT_FALSE(some_switch); EXPECT_EQ(argc, 1); } @@ -121,7 +120,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) { 
Flags::Parse(&argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_float", &some_float, "some float")}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -134,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 2); } @@ -147,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -160,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -173,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 2); } @@ -235,11 +234,125 @@ TEST(CommandLineFlagsTest, UsageString) { << usage; } +// When there are duplicate args, the flag value and the parsing result will be +// based on the 1st arg. 
+TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) { + int some_int = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=2", argv_strings[1]); +} + +TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) { + int some_int = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=value", + "--some_int=1"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(-23, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=1", argv_strings[1]); +} + +TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) { + int some_int = -23; + int argc = 3; + // Although the 2nd arg has non-parsable int value, the flag 'some_int' value + // could still be set and the parsing result is ok. + const char* argv_strings[] = {"program_name", "--some_int=1", + "--some_int=value"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=value", argv_strings[1]); +} + +// When there are duplicate flags, all of them will be checked against the arg +// list. 
+TEST(CommandLineFlagsTest, DuplicateFlags) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=1"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "some int1"), + Flag::CreateFlag("some_int", &some_int2, "some int2")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int1); + EXPECT_EQ(1, some_int2); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=1.0"}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL), + Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(-23, some_int1); + EXPECT_EQ(-23, some_int2); + EXPECT_EQ(argc, 2); +} + +TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) { + int some_int = -23; + bool some_bool = true; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_val=20"}; + // In this case, the 2nd 'some_val' flag of bool type will cause a no-ok + // parsing result. 
+ bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_val", &some_int, "some val-int"), + Flag::CreateFlag("some_val", &some_bool, "some val-bool")}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(20, some_int); + EXPECT_TRUE(some_bool); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=1 --some_int=2"}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"), + Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")}); + + // Note, when there're duplicate args, the flag value and the parsing result + // will be based on the 1st arg (i.e. --some_int=1). And both duplicate flags + // (i.e. flag1 and flag2) are checked, thus resulting their associated values + // (some_int1 and some_int2) being set to 1. 
+ EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int1); + EXPECT_EQ(1, some_int2); + EXPECT_EQ(argc, 2); +} + } // namespace } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD index caa4a637766..bf21e553b1f 100644 --- a/tensorflow/lite/tools/evaluation/BUILD +++ b/tensorflow/lite/tools/evaluation/BUILD @@ -70,6 +70,10 @@ cc_library( copts = tflite_copts(), deps = [ ":utils", + "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools:tool_params", + "//tensorflow/lite/tools/benchmark:delegate_provider_hdr", + "//tensorflow/lite/tools/benchmark:tflite_execution_providers", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", ], ) diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc index 925cae8d140..91a0aea4711 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" +#include "tensorflow/lite/tools/command_line_flags.h" + namespace tflite { namespace evaluation { namespace { @@ -36,7 +38,6 @@ TfliteInferenceParams::Delegate ParseStringToDelegateType( TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, std::string* error_msg) { const auto type = params.delegate(); - switch (type) { case TfliteInferenceParams::NNAPI: { auto p = CreateNNAPIDelegate(); @@ -76,5 +77,86 @@ TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, } } +DelegateProviders::DelegateProviders() + : delegates_list_(benchmark::GetRegisteredDelegateProviders()), + delegates_map_([=]() -> std::unordered_map { + std::unordered_map delegates_map; + for (int i = 0; i < delegates_list_.size(); ++i) { + delegates_map[delegates_list_[i]->GetName()] = i; + } + return delegates_map; + }()) { + for (const auto& one : delegates_list_) { + params_.Merge(one->DefaultParams()); + } +} + +bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) { + std::vector flags; + for (const auto& one : delegates_list_) { + auto one_flags = one->CreateFlags(¶ms_); + flags.insert(flags.end(), one_flags.begin(), one_flags.end()); + } + return Flags::Parse(argc, argv, flags); +} + +TfLiteDelegatePtr DelegateProviders::CreateDelegate( + const std::string& name) const { + const auto it = delegates_map_.find(name); + if (it == delegates_map_.end()) { + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); + } + return delegates_list_[it->second]->CreateTfLiteDelegate(params_); +} + +std::vector DelegateProviders::CreateAllDelegates( + const tools::ToolParams& params) const { + std::vector delegates; + for (const auto& one : delegates_list_) { + auto ptr = one->CreateTfLiteDelegate(params); + // It's possible that a delegate of certain type won't be created as + // user-specified benchmark params tells not to. 
+ if (ptr == nullptr) continue; + delegates.emplace_back(std::move(ptr)); + } + return delegates; +} + +std::vector DelegateProviders::CreateAllDelegates( + const TfliteInferenceParams& params) const { + tools::ToolParams merged_params; + merged_params.Merge(params_); + + const auto type = params.delegate(); + switch (type) { + case TfliteInferenceParams::NNAPI: + if (merged_params.HasParam("use_nnapi")) { + merged_params.Set("use_nnapi", true); + } + break; + case TfliteInferenceParams::GPU: + if (merged_params.HasParam("use_gpu")) { + merged_params.Set("use_gpu", true); + } + break; + case TfliteInferenceParams::HEXAGON: + if (merged_params.HasParam("use_hexagon")) { + merged_params.Set("use_hexagon", true); + } + break; + case TfliteInferenceParams::XNNPACK: + if (merged_params.HasParam("use_xnnpack")) { + merged_params.Set("use_xnnpack", true); + } + if (params.has_num_threads()) { + merged_params.Set("num_threads", params.num_threads()); + } + break; + default: + break; + } + return CreateAllDelegates(merged_params); +} + } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h index 7f093295be2..5c5c4bb1021 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h @@ -16,12 +16,60 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_ #define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_ +#include +#include +#include + +#include "tensorflow/lite/tools/benchmark/delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/utils.h" +#include "tensorflow/lite/tools/tool_params.h" namespace tflite { namespace evaluation { +class DelegateProviders { + public: + DelegateProviders(); + + // Initialize delegate-related parameters from commandline arguments and + // returns true if sucessful. + bool InitFromCmdlineArgs(int* argc, const char** argv); + + // Get all parameters from all registered delegate providers. + const tools::ToolParams& GetAllParams() const { return params_; } + + // Create the a TfLite delegate instance based on the provided delegate + // 'name'. If the specified one isn't found, an empty TfLiteDelegatePtr is + // returned. + TfLiteDelegatePtr CreateDelegate(const std::string& name) const; + + // Create a list of TfLite delegates based on what have been initialized (i.e. + // 'params_'). + std::vector CreateAllDelegates() const { + return CreateAllDelegates(params_); + } + + // Create a list of TfLite delegates based on the given TfliteInferenceParams + // 'params' but considering what have been initialized (i.e. 'params_'). + std::vector CreateAllDelegates( + const TfliteInferenceParams& params) const; + + private: + // Create a list of TfLite delegates based on the provided 'params'. + std::vector CreateAllDelegates( + const tools::ToolParams& params) const; + + // Contain delegate-related parameters that are initialized from command-line + // flags. + tools::ToolParams params_; + + const benchmark::DelegateProviderList& delegates_list_; + // Key is the delegate name, and the value is the index to the + // 'delegates_list_'. 
+ const std::unordered_map delegates_map_; +}; + // Parse a string 'val' to the corresponding delegate type defined by // TfliteInferenceParams::Delegate. TfliteInferenceParams::Delegate ParseStringToDelegateType( diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc index 1b984206eb6..1d7870eaed0 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc @@ -39,6 +39,21 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) { EXPECT_TRUE(!CreateTfLiteDelegate(params)); } +TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) { + DelegateProviders providers; + const auto& params = providers.GetAllParams(); + EXPECT_TRUE(params.HasParam("use_nnapi")); + EXPECT_TRUE(params.HasParam("use_gpu")); + + int argc = 3; + const char* argv[] = {"program_name", "--use_gpu=true", + "--other_undefined_flag=1"}; + EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv)); + EXPECT_TRUE(params.Get("use_gpu")); + EXPECT_EQ(2, argc); + EXPECT_EQ("--other_undefined_flag=1", argv[1]); +} + } // namespace } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index ea3341f4e75..1650151bfa7 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -147,6 +147,7 @@ cc_library( ":tflite_inference_stage", ":topk_accuracy_eval_stage", "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", @@ -163,6 +164,7 @@ cc_library( ":tflite_inference_stage", "//tensorflow/core:tflite_portable_logging", 
"//tensorflow/core/util:stats_calculator_portable", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", @@ -224,6 +226,7 @@ cc_library( ":tflite_inference_stage", "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/c:common", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc index c9f8f832441..212e148cbc7 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc @@ -29,7 +29,8 @@ namespace { const float kCroppingFraction = 0.875; } // namespace -TfLiteStatus ImageClassificationStage::Init() { +TfLiteStatus ImageClassificationStage::Init( + const DelegateProviders* delegate_providers) { // Ensure inference params are provided. if (!config_.specification().has_image_classification_params()) { LOG(ERROR) << "ImageClassificationParams not provided"; @@ -47,7 +48,8 @@ TfLiteStatus ImageClassificationStage::Init() { *tflite_inference_config.mutable_specification() ->mutable_tflite_inference_params() = params.inference_params(); inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config)); - if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError; + if (inference_stage_->Init(delegate_providers) != kTfLiteOk) + return kTfLiteError; // Validate model inputs. 
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo(); diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h index a74a5979f35..c3f8eb8f900 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h" @@ -36,7 +37,8 @@ class ImageClassificationStage : public EvaluationStage { explicit ImageClassificationStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); TfLiteStatus Run() override; diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc index cfafc1e9214..8a3759a17c2 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc @@ -68,11 +68,12 @@ float CalculateAverageError(T* reference, T* test, int64_t num_elements) { } // namespace -TfLiteStatus InferenceProfilerStage::Init() { +TfLiteStatus InferenceProfilerStage::Init( + const DelegateProviders* delegate_providers) { // Initialize TfliteInferenceStage with the user-provided // TfliteInferenceParams. 
test_stage_.reset(new TfliteInferenceStage(config_)); - if (test_stage_->Init() != kTfLiteOk) return kTfLiteError; + if (test_stage_->Init(delegate_providers) != kTfLiteOk) return kTfLiteError; LOG(INFO) << "Test interpreter has been initialized."; // Initialize a reference TfliteInferenceStage that uses the given model & diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h index e5fd37943e9..d10c7beb088 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/core/util/stats_calculator.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h" @@ -39,7 +40,8 @@ class InferenceProfilerStage : public EvaluationStage { explicit InferenceProfilerStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); // New Gaussian random data is used as input for each Run. TfLiteStatus Run() override; diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc index f7821d81894..1ed8db2076c 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc @@ -26,7 +26,8 @@ limitations under the License. 
namespace tflite { namespace evaluation { -TfLiteStatus ObjectDetectionStage::Init() { +TfLiteStatus ObjectDetectionStage::Init( + const DelegateProviders* delegate_providers) { // Ensure inference params are provided. if (!config_.specification().has_object_detection_params()) { LOG(ERROR) << "ObjectDetectionParams not provided"; @@ -48,7 +49,7 @@ TfLiteStatus ObjectDetectionStage::Init() { *tflite_inference_config.mutable_specification() ->mutable_tflite_inference_params() = params.inference_params(); inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config)); - TF_LITE_ENSURE_STATUS(inference_stage_->Init()); + TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers)); // Validate model inputs. const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo(); diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h index cc0c935bba9..1489d853c34 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" @@ -43,7 +44,8 @@ class ObjectDetectionStage : public EvaluationStage { explicit ObjectDetectionStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); TfLiteStatus Run() override; diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index cbf41de0e03..bb01ca2cc4d 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/profiling/time.h" -#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/utils.h" @@ -71,7 +70,8 @@ TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate( return kTfLiteOk; } -TfLiteStatus TfliteInferenceStage::Init() { +TfLiteStatus TfliteInferenceStage::Init( + const DelegateProviders* delegate_providers) { if (!config_.specification().has_tflite_inference_params()) { LOG(ERROR) << "TfliteInferenceParams not provided"; return kTfLiteError; @@ -96,14 +96,19 @@ TfLiteStatus TfliteInferenceStage::Init() { } interpreter_->SetNumThreads(params.num_threads()); - std::string error_message; - auto delegate = CreateTfLiteDelegate(params, &error_message); - if (delegate) { - delegates_.push_back(std::move(delegate)); - LOG(INFO) << "Successfully created " - << params.Delegate_Name(params.delegate()) << " delegate."; + if (!delegate_providers) { + std::string error_message; + auto delegate = CreateTfLiteDelegate(params, &error_message); + if (delegate) { + delegates_.push_back(std::move(delegate)); + LOG(INFO) << "Successfully created " + << params.Delegate_Name(params.delegate()) << " delegate."; + } else { + LOG(WARNING) << error_message; + } } else { - LOG(WARNING) << error_message; + auto delegates = delegate_providers->CreateAllDelegates(params); + for (auto& one : delegates) delegates_.push_back(std::move(one)); } for (int i = 0; i < delegates_.size(); ++i) { diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h index c27462fdcd6..a8a319fcd16 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h @@ -24,6 +24,7 @@ limitations under the 
License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" @@ -41,14 +42,15 @@ class TfliteInferenceStage : public EvaluationStage { explicit TfliteInferenceStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); TfLiteStatus Run() override; // EvaluationStageMetrics.num_runs denotes the number of inferences run. EvaluationStageMetrics LatestMetrics() override; - ~TfliteInferenceStage() {} + ~TfliteInferenceStage() override {} // Call before Run(). // This class does not take ownership of raw_input_ptrs. diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc index 39b5082accb..3b5fc08ab84 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc @@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path, const std::vector& image_paths, const std::string& ground_truth_proto_file, std::string delegate, std::string output_file_path, - int num_interpreter_threads, bool debug_mode) { + int num_interpreter_threads, bool debug_mode, + const DelegateProviders& delegate_providers) { EvaluationStageConfig eval_config; eval_config.set_name("object_detection"); auto* detection_params = @@ -74,7 +75,7 @@ bool EvaluateModel(const std::string& model_file_path, ObjectDetectionStage eval(eval_config); eval.SetAllLabels(model_labels); - if (eval.Init() != kTfLiteOk) return false; + if 
(eval.Init(&delegate_providers) != kTfLiteOk) return false; // Open output file for writing. std::ofstream ofile; @@ -156,6 +157,8 @@ int Main(int argc, char* argv[]) { "Must be one of {'nnapi', 'gpu'}"), }; tflite::Flags::Parse(&argc, const_cast(argv), flag_list); + DelegateProviders delegate_providers; + delegate_providers.InitFromCmdlineArgs(&argc, const_cast(argv)); // Process images in filename-sorted order. std::vector image_paths; @@ -170,7 +173,7 @@ int Main(int argc, char* argv[]) { if (!EvaluateModel(model_file_path, model_labels, image_paths, ground_truth_proto_file, delegate, output_file_path, - num_interpreter_threads, debug_mode)) { + num_interpreter_threads, debug_mode, delegate_providers)) { LOG(ERROR) << "Could not evaluate model"; return EXIT_FAILURE; } diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index 5268039c500..cd6c6cfb3c4 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path, const std::vector& image_labels, const std::vector& model_labels, std::string delegate, std::string output_file_path, - int num_interpreter_threads) { + int num_interpreter_threads, + const DelegateProviders& delegate_providers) { EvaluationStageConfig eval_config; eval_config.set_name("image_classification"); auto* classification_params = eval_config.mutable_specification() @@ -69,7 +70,7 @@ bool EvaluateModel(const std::string& model_file_path, ImageClassificationStage eval(eval_config); eval.SetAllLabels(model_labels); - if (eval.Init() != kTfLiteOk) return false; + if (eval.Init(&delegate_providers) != kTfLiteOk) return false; const int step = image_labels.size() / 100; for (int i = 0; i < image_labels.size(); ++i) { @@ -135,6 +136,8 @@ int 
Main(int argc, char* argv[]) { "Must be one of {'nnapi', 'gpu'}"), }; tflite::Flags::Parse(&argc, const_cast(argv), flag_list); + DelegateProviders delegate_providers; + delegate_providers.InitFromCmdlineArgs(&argc, const_cast(argv)); // Process images in filename-sorted order. std::vector image_files, ground_truth_image_labels; @@ -168,7 +171,8 @@ int Main(int argc, char* argv[]) { } if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate, - output_file_path, num_interpreter_threads)) { + output_file_path, num_interpreter_threads, + delegate_providers)) { LOG(ERROR) << "Could not evaluate model"; return EXIT_FAILURE; } diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc index cdd83d52d6f..6a7c6e8fc42 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc @@ -36,7 +36,8 @@ constexpr char kDelegateFlag[] = "delegate"; bool EvaluateModel(const std::string& model_file_path, const std::string& delegate, int num_runs, const std::string& output_file_path, - int num_interpreter_threads) { + int num_interpreter_threads, + const DelegateProviders& delegate_providers) { // Initialize evaluation stage. EvaluationStageConfig eval_config; eval_config.set_name("inference_profiling"); @@ -54,7 +55,7 @@ bool EvaluateModel(const std::string& model_file_path, return false; } InferenceProfilerStage eval(eval_config); - if (eval.Init() != kTfLiteOk) return false; + if (eval.Init(&delegate_providers) != kTfLiteOk) return false; // Run inference & check diff for specified number of runs. 
for (int i = 0; i < num_runs; ++i) { @@ -94,8 +95,10 @@ int Main(int argc, char* argv[]) { }; tflite::Flags::Parse(&argc, const_cast(argv), flag_list); + DelegateProviders delegate_providers; + delegate_providers.InitFromCmdlineArgs(&argc, const_cast(argv)); if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path, - num_interpreter_threads)) { + num_interpreter_threads, delegate_providers)) { LOG(ERROR) << "Could not evaluate model!"; return EXIT_FAILURE; } diff --git a/tensorflow/lite/tools/logging.h b/tensorflow/lite/tools/logging.h new file mode 100644 index 00000000000..b832e387993 --- /dev/null +++ b/tensorflow/lite/tools/logging.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_LOGGING_H_ +#define TENSORFLOW_LITE_TOOLS_LOGGING_H_ + +// LOG and CHECK macros for tflite tooling. + +#include +#include +#include + +#ifdef _WIN32 +#undef ERROR +#endif + +namespace tflite { +namespace logging { +// A wrapper that logs to stderr. +// +// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. 
+class LoggingWrapper { + public: + enum class LogSeverity : int { + INFO = 0, + WARN = 1, + ERROR = 2, + FATAL = 3, + }; + LoggingWrapper(LogSeverity severity) + : severity_(severity), should_log_(true) {} + LoggingWrapper(LogSeverity severity, bool log) + : severity_(severity), should_log_(log) {} + std::stringstream& Stream() { return stream_; } + ~LoggingWrapper() { + if (should_log_) { + switch (severity_) { + case LogSeverity::INFO: + case LogSeverity::WARN: + std::cout << stream_.str() << std::endl; + break; + case LogSeverity::ERROR: + std::cerr << stream_.str() << std::endl; + break; + case LogSeverity::FATAL: + std::cerr << stream_.str() << std::endl; + std::flush(std::cerr); + std::abort(); + break; + } + } + } + + private: + std::stringstream stream_; + LogSeverity severity_; + bool should_log_; +}; +} // namespace logging +} // namespace tflite + +#define TFLITE_LOG(severity) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::severity) \ + .Stream() + +#define TFLITE_TOOLS_CHECK(condition) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ + (condition) ? 
false : true) \ + .Stream() + +#define TFLITE_TOOLS_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK((a) == (b)) + +#endif // TENSORFLOW_LITE_TOOLS_LOGGING_H_ diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 7a77cf2b3f5..426ed63b482 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -111,7 +111,8 @@ PROFILE_SUMMARIZER_SRCS := \ tensorflow/core/util/stats_calculator.cc CMD_LINE_TOOLS_SRCS := \ - tensorflow/lite/tools/command_line_flags.cc + tensorflow/lite/tools/command_line_flags.cc \ + tensorflow/lite/tools/tool_params.cc CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/lite/*.cc) \ @@ -211,6 +212,7 @@ TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # Benchmark sources BENCHMARK_SRCS_DIR := tensorflow/lite/tools/benchmark +DELEGATE_PROVIDER_SRCS_DIR := tensorflow/lite/tools/benchmark EVALUATION_UTILS_SRCS := \ tensorflow/lite/tools/evaluation/utils.cc BENCHMARK_ALL_SRCS := \ @@ -227,11 +229,12 @@ BENCHMARK_LIB_SRCS := $(filter-out \ $(BENCHMARK_MAIN_SRC) \ $(BENCHMARK_PERF_OPTIONS_SRC) \ $(BENCHMARK_SRCS_DIR)/benchmark_plus_flex_main.cc \ - $(BENCHMARK_SRCS_DIR)/external_delegate_provider.cc \ - $(BENCHMARK_SRCS_DIR)/gpu_delegate_provider.cc \ - $(BENCHMARK_SRCS_DIR)/hexagon_delegate_provider.cc \ - $(BENCHMARK_SRCS_DIR)/nnapi_delegate_provider.cc \ - $(BENCHMARK_SRCS_DIR)/xnnpack_delegate_provider.cc, \ + $(DELEGATE_PROVIDER_SRCS_DIR)/default_execution_provider.cc \ + $(DELEGATE_PROVIDER_SRCS_DIR)/external_delegate_provider.cc \ + $(DELEGATE_PROVIDER_SRCS_DIR)/gpu_delegate_provider.cc \ + $(DELEGATE_PROVIDER_SRCS_DIR)/hexagon_delegate_provider.cc \ + $(DELEGATE_PROVIDER_SRCS_DIR)/nnapi_delegate_provider.cc \ + $(DELEGATE_PROVIDER_SRCS_DIR)/xnnpack_delegate_provider.cc, \ $(BENCHMARK_ALL_SRCS)) # These target-specific makefiles should modify or replace options like diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc 
b/tensorflow/lite/tools/optimize/quantize_model.cc index 5ebc513f5bc..6c6ba10d60c 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -445,7 +445,7 @@ TfLiteStatus QuantizeOpInput( } if (utils::QuantizeWeight(model, tensor, tensor_property.per_axis, tensor_property.per_axis_index, - error_reporter) == kTfLiteError) { + error_reporter) != kTfLiteOk) { TF_LITE_REPORT_ERROR( error_reporter, "Unable to quantize buffer or min/max value for input %d " @@ -1001,9 +1001,20 @@ TfLiteStatus FillQuantizationParams( // Dynamic tensor. } else if (!utils::HasMinMax(tensor) && !utils::HasBuffer(model, subgraph, tensor_idx)) { - TF_LITE_REPORT_ERROR(error_reporter, - "Max and min for dynamic tensors should be" - " recorded during calibration"); + TF_LITE_REPORT_ERROR( + error_reporter, + "Max and min for dynamic tensors should be" + " recorded during calibration: Failed for tensor %s\n", + tensor->name.c_str()); + if (tensor->quantization == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter, + "No quantization params for tensor %s", + tensor->name.c_str()); + } else if (tensor->quantization->min.empty() || + tensor->quantization->max.empty()) { + TF_LITE_REPORT_ERROR(error_reporter, "Empty min/max for tensor %s", + tensor->name.c_str()); + } return kTfLiteError; } diff --git a/tensorflow/lite/tools/optimize/sparsity/BUILD b/tensorflow/lite/tools/optimize/sparsity/BUILD index b68094849c1..4ea901f77f9 100644 --- a/tensorflow/lite/tools/optimize/sparsity/BUILD +++ b/tensorflow/lite/tools/optimize/sparsity/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite/micro:build_def.bzl", "cc_library") package( default_visibility = [ @@ -11,6 +12,7 @@ cc_library( name = "format_converter", srcs = ["format_converter.cc"], hdrs = ["format_converter.h"], + build_for_embedded = True, copts = tflite_copts(), deps = [ "//tensorflow/lite/c:common", diff --git 
a/tensorflow/lite/tools/optimize/sparsity/python/format_converter_extension.cc b/tensorflow/lite/tools/optimize/sparsity/python/format_converter_extension.cc index 59cd7b46fa0..5a0091e9e89 100644 --- a/tensorflow/lite/tools/optimize/sparsity/python/format_converter_extension.cc +++ b/tensorflow/lite/tools/optimize/sparsity/python/format_converter_extension.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h" diff --git a/tensorflow/lite/tools/randomize_weights.py b/tensorflow/lite/tools/randomize_weights.py index 84bbe3955a7..b68bdbb180b 100644 --- a/tensorflow/lite/tools/randomize_weights.py +++ b/tensorflow/lite/tools/randomize_weights.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Randomize all weights in a tflite file. +r"""Randomize all weights in a tflite file. Example usage: -python randomize_weights.py foo.tflite foo_randomized.tflite +python randomize_weights.py \ + --input_tflite_file=foo.tflite \ + --output_tflite_file=foo_randomized.tflite """ from __future__ import absolute_import diff --git a/tensorflow/lite/tools/strip_strings.py b/tensorflow/lite/tools/strip_strings.py index cc88562caf1..e24d2b737c5 100644 --- a/tensorflow/lite/tools/strip_strings.py +++ b/tensorflow/lite/tools/strip_strings.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Strips all nonessential strings from a tflite file. +r"""Strips all nonessential strings from a tflite file. 
Example usage: -python strip_strings.py foo.tflite foo_stripped.tflite +python strip_strings.py \ + --input_tflite_file=foo.tflite \ + --output_tflite_file=foo_stripped.tflite """ from __future__ import absolute_import diff --git a/tensorflow/lite/tools/tool_params.cc b/tensorflow/lite/tools/tool_params.cc new file mode 100644 index 00000000000..678b81b7784 --- /dev/null +++ b/tensorflow/lite/tools/tool_params.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/tools/tool_params.h" + +#include +#include +#include + +#include "tensorflow/lite/tools/logging.h" + +namespace tflite { +namespace tools { + +void ToolParam::AssertHasSameType(ToolParam::ParamType a, + ToolParam::ParamType b) { + TFLITE_TOOLS_CHECK(a == b) << "Type mismatch while accessing parameter."; +} + +template <> +ToolParam::ParamType ToolParam::GetValueType() { + return ToolParam::ParamType::TYPE_INT32; +} + +template <> +ToolParam::ParamType ToolParam::GetValueType() { + return ToolParam::ParamType::TYPE_BOOL; +} + +template <> +ToolParam::ParamType ToolParam::GetValueType() { + return ToolParam::ParamType::TYPE_FLOAT; +} + +template <> +ToolParam::ParamType ToolParam::GetValueType() { + return ToolParam::ParamType::TYPE_STRING; +} + +void ToolParams::AssertParamExists(const std::string& name) const { + TFLITE_TOOLS_CHECK(HasParam(name)) << name << " was not found."; +} + +void ToolParams::Set(const ToolParams& other) { + for (const auto& param : params_) { + const ToolParam* other_param = other.GetParam(param.first); + if (other_param == nullptr) continue; + param.second->Set(*other_param); + } +} + +void ToolParams::Merge(const ToolParams& other, bool overwrite) { + for (const auto& one : other.params_) { + auto it = params_.find(one.first); + if (it == params_.end()) { + AddParam(one.first, one.second->Clone()); + } else if (overwrite) { + it->second->Set(*one.second); + } + } +} + +} // namespace tools +} // namespace tflite diff --git a/tensorflow/lite/tools/tool_params.h b/tensorflow/lite/tools/tool_params.h new file mode 100644 index 00000000000..30961473c8c --- /dev/null +++ b/tensorflow/lite/tools/tool_params.h @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_ +#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_ +#include +#include +#include +#include +#include + +namespace tflite { +namespace tools { + +template +class TypedToolParam; + +class ToolParam { + protected: + enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + template + static ParamType GetValueType(); + + public: + template + static std::unique_ptr Create(const T& default_value) { + return std::unique_ptr(new TypedToolParam(default_value)); + } + + template + TypedToolParam* AsTyped() { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + + template + const TypedToolParam* AsConstTyped() const { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + + virtual ~ToolParam() {} + explicit ToolParam(ParamType type) : type_(type) {} + + virtual void Set(const ToolParam&) {} + + virtual std::unique_ptr Clone() const = 0; + + private: + static void AssertHasSameType(ParamType a, ParamType b); + + const ParamType type_; +}; + +template +class TypedToolParam : public ToolParam { + public: + explicit TypedToolParam(const T& value) + : ToolParam(GetValueType()), value_(value) {} + + void Set(const T& value) { value_ = value; } + + T Get() const { return value_; } + + void Set(const ToolParam& other) override { + Set(other.AsConstTyped()->Get()); + } + + std::unique_ptr Clone() const override { + return std::unique_ptr(new TypedToolParam(value_)); + } + + private: + T value_; +}; + 
+// A map-like container for holding values of different types. +class ToolParams { + public: + void AddParam(const std::string& name, std::unique_ptr value) { + params_[name] = std::move(value); + } + + bool HasParam(const std::string& name) const { + return params_.find(name) != params_.end(); + } + + bool Empty() const { return params_.empty(); } + + const ToolParam* GetParam(const std::string& name) const { + const auto& entry = params_.find(name); + if (entry == params_.end()) return nullptr; + return entry->second.get(); + } + + template + void Set(const std::string& name, const T& value) { + AssertParamExists(name); + params_.at(name)->AsTyped()->Set(value); + } + + template + T Get(const std::string& name) const { + AssertParamExists(name); + return params_.at(name)->AsTyped()->Get(); + } + + // Set the value of all same parameters from 'other'. + void Set(const ToolParams& other); + + // Merge the value of all parameters from 'other'. 'overwrite' indicates + // whether the value of the same parameter is overwritten or not. + void Merge(const ToolParams& other, bool overwrite = false); + + private: + void AssertParamExists(const std::string& name) const; + std::unordered_map> params_; +}; + +} // namespace tools +} // namespace tflite +#endif // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_ diff --git a/tensorflow/lite/tools/tool_params_test.cc b/tensorflow/lite/tools/tool_params_test.cc new file mode 100644 index 00000000000..3bdead3e7c8 --- /dev/null +++ b/tensorflow/lite/tools/tool_params_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/tool_params.h" + +#include +#include + +namespace tflite { +namespace tools { +namespace { + +TEST(ToolParams, SetTest) { + ToolParams params; + params.AddParam("some-int1", ToolParam::Create(13)); + params.AddParam("some-int2", ToolParam::Create(17)); + + ToolParams others; + others.AddParam("some-int1", ToolParam::Create(19)); + others.AddParam("some-bool", ToolParam::Create(true)); + + params.Set(others); + EXPECT_EQ(19, params.Get("some-int1")); + EXPECT_EQ(17, params.Get("some-int2")); + EXPECT_FALSE(params.HasParam("some-bool")); +} + +TEST(ToolParams, MergeTestOverwriteTrue) { + ToolParams params; + params.AddParam("some-int1", ToolParam::Create(13)); + params.AddParam("some-int2", ToolParam::Create(17)); + + ToolParams others; + others.AddParam("some-int1", ToolParam::Create(19)); + others.AddParam("some-bool", ToolParam::Create(true)); + + params.Merge(others, true /* overwrite */); + EXPECT_EQ(19, params.Get("some-int1")); + EXPECT_EQ(17, params.Get("some-int2")); + EXPECT_TRUE(params.Get("some-bool")); +} + +TEST(ToolParams, MergeTestOverwriteFalse) { + ToolParams params; + params.AddParam("some-int1", ToolParam::Create(13)); + params.AddParam("some-int2", ToolParam::Create(17)); + + ToolParams others; + others.AddParam("some-int1", ToolParam::Create(19)); + others.AddParam("some-bool", ToolParam::Create(true)); + + params.Merge(others); // default overwrite is false + EXPECT_EQ(13, params.Get("some-int1")); + EXPECT_EQ(17, 
params.Get("some-int2")); + EXPECT_TRUE(params.Get("some-bool")); +} +} // namespace +} // namespace tools +} // namespace tflite diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 7a10dc5637f..ae84bdc9695 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -172,6 +172,18 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_AVERAGE_POOL_2D: + if (op_sig.input_types.at(0) == TensorType_INT16 && + op_sig.output_types.at(0) == TensorType_INT16) { + return 3; + } + + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 2; + } + return 1; + case BuiltinOperator_TRANSPOSE: if (op_sig.options.single_input_op.num_dims > 4) { return 4; @@ -399,10 +411,16 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; - case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_FILL: + if (op_sig.input_types.size() >= 2 && + (op_sig.input_types.at(1) == TensorType_BOOL || + op_sig.input_types.at(1) == TensorType_STRING)) { + return 2; + } + return 1; + case BuiltinOperator_ADD: case BuiltinOperator_CONCATENATION: - case BuiltinOperator_MAX_POOL_2D: case BuiltinOperator_PAD: case BuiltinOperator_PADV2: case BuiltinOperator_SOFTMAX: diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index ae4efce2544..5dde260241e 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -554,4 +554,18 @@ TEST(OpVersionTest, VersioningDivTest) { fake_op_sig.options.broadcast.num_dims = 4; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); } +TEST(OpVersionTest, VersioningFillTest) { + OpSignature fake_op_sig = {.op = BuiltinOperator_FILL, + .input_types = std::vector{ + TensorType_INT32, TensorType_BOOL}}; + 
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig = {.op = BuiltinOperator_FILL, + .input_types = std::vector{TensorType_INT32, + TensorType_STRING}}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig = {.op = BuiltinOperator_FILL, + .input_types = std::vector{TensorType_INT32, + TensorType_INT32}}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} } // namespace tflite diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 1039d10a4f1..f3449f80986 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -13,7 +13,7 @@ load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow:tensorflow.bzl", "pybind_extension") # buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "pywrap_tensorflow_macro") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_test") @@ -5832,16 +5832,14 @@ py_library( ["pywrap_dlopen_global_flags.py"], # Import will fail, indicating no global dlopen flags otherwise = [], - ), + ), # b/153585257 srcs_version = "PY2AND3", deps = [":pywrap_tensorflow_internal"], ) -# TODO(b/137885063): This macro should be cleaned up to be just -# a py_binary/library shared object. We have removed the source -# SWIG file to prevent people from adding SWIG code. 
-tf_py_wrap_cc( +pywrap_tensorflow_macro( name = "pywrap_tensorflow_internal", + srcs = ["pywrap_tensorflow_internal.cc"], # add win_def_file for pywrap_tensorflow win_def_file = select({ "//tensorflow:windows": ":pywrap_tensorflow_filtered_def_file", diff --git a/tensorflow/python/autograph/impl/testing/pybind_for_testing.cc b/tensorflow/python/autograph/impl/testing/pybind_for_testing.cc index cdeabde8390..09a86afc594 100644 --- a/tensorflow/python/autograph/impl/testing/pybind_for_testing.cc +++ b/tensorflow/python/autograph/impl/testing/pybind_for_testing.cc @@ -13,9 +13,9 @@ // limitations under the License. // ============================================================================== -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" namespace autograph { diff --git a/tensorflow/python/client/debug_events_writer_wrapper.cc b/tensorflow/python/client/debug_events_writer_wrapper.cc index 75abf70d749..a786c6f2db6 100644 --- a/tensorflow/python/client/debug_events_writer_wrapper.cc +++ b/tensorflow/python/client/debug_events_writer_wrapper.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "absl/strings/string_view.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/util/debug_events_writer.h" diff --git a/tensorflow/python/client/device_lib_wrapper.cc b/tensorflow/python/client/device_lib_wrapper.cc index 62ce47669d8..70f6edb1903 100644 --- a/tensorflow/python/client/device_lib_wrapper.cc +++ b/tensorflow/python/client/device_lib_wrapper.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/device_attributes.pb.h" diff --git a/tensorflow/python/client/events_writer_wrapper.cc b/tensorflow/python/client/events_writer_wrapper.cc index 22b3811c93d..d7dbd3f3b4d 100644 --- a/tensorflow/python/client/events_writer_wrapper.cc +++ b/tensorflow/python/client/events_writer_wrapper.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "absl/strings/string_view.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/util/events_writer.h" #include "tensorflow/python/lib/core/pybind11_absl.h" #include "tensorflow/python/lib/core/pybind11_proto.h" diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 46d48b12dfa..14fda32f956 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -15,11 +15,11 @@ limitations under the License. 
#include "Python.h" #include "absl/types/optional.h" -#include "include/pybind11/chrono.h" -#include "include/pybind11/complex.h" -#include "include/pybind11/functional.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/chrono.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_api_internal.h" diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index b7a5d32fa32..3a16062ad87 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 8) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 10) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py index 5be6cca9332..521b38bf5d3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py @@ -270,21 +270,19 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase): "counter", (), dtypes.int32, use_resource=True) dataset0 = dataset_ops.Dataset.range(100).map( lambda _: counter_var.assign_add(1)) - with self.assertRaises(errors.InvalidArgumentError): - replicated_ds = distribute.replicate(dataset0, - [self._device1, self._device2]) - dataset1 = replicated_ds[self._device1] - dataset2 = replicated_ds[self._device2] - with ops.device(self._device0): - get_next0 = self.getNext(dataset0) - with ops.device(self._device1): - get_next1 = self.getNext(dataset1) - with ops.device(self._device2): - get_next2 = self.getNext(dataset2) - for _ in range(100): - self.evaluate(get_next0()) - self.evaluate(get_next1()) - self.evaluate(get_next2()) + replicated_ds = distribute.replicate(dataset0, + [self._device1, self._device2]) + dataset1 = replicated_ds[self._device1] + dataset2 = replicated_ds[self._device2] + with ops.device(self._device0): + self.assertDatasetProduces( + dataset0, range(1, 101), requires_initialization=True) + with ops.device(self._device1): + self.assertDatasetProduces( + dataset1, range(101, 201), requires_initialization=True) + with ops.device(self._device2): + self.assertDatasetProduces( + dataset2, range(201, 301), requires_initialization=True) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py 
b/tensorflow/python/data/kernel_tests/from_generator_test.py index ecc022b58a5..aaa66f353e5 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -445,6 +445,30 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([20], self.evaluate(get_next())) + @combinations.generate(test_base.default_test_combinations()) + def testTypeIsListError(self): + + def generator(): + for _ in range(10): + yield [20] + + with self.assertRaisesRegexp( + TypeError, r"Cannot convert value \[tf.int64\] to a TensorFlow DType"): + dataset_ops.Dataset.from_generator( + generator, output_types=[dtypes.int64]) + + @combinations.generate(test_base.default_test_combinations()) + def testDimensionIsListError(self): + + def generator(): + for _ in range(10): + yield [20] + + with self.assertRaisesRegexp( + TypeError, r"Failed to convert '\[\[1\]\]' to a shape"): + dataset_ops.Dataset.from_generator( + generator, output_types=(dtypes.int64), output_shapes=[[1]]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py index 9ed730e7bca..a16518d1111 100644 --- a/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from absl.testing import parameterized import numpy as np @@ -63,6 +65,17 @@ class FromTensorSlicesTest(test_base.DatasetTestBase, parameterized.TestCase): ds = ds.flat_map(lambda x: x) self.assertDatasetProduces(ds, expected_output=list(range(10)) * 10) + @combinations.generate(test_base.default_test_combinations()) + def testFromTensorSlicesDatasetOfOrderedDict(self): + dss = [dataset_ops.Dataset.range(10).map( + 
lambda x: collections.OrderedDict([("x", x)])) for _ in range(10)] + ds = dataset_ops.Dataset.from_tensor_slices(dss) + ds = ds.flat_map(lambda x: x) + self.assertDatasetProduces( + ds, + expected_output=[collections.OrderedDict([("x", x)]) + for x in list(range(10)) * 10]) + @combinations.generate(test_base.default_test_combinations()) def testFromTensorSlicesDatasetInFunction(self): dss = [dataset_ops.Dataset.range(10) for _ in range(10)] diff --git a/tensorflow/python/data/kernel_tests/interleave_test.py b/tensorflow/python/data/kernel_tests/interleave_test.py index 716cddcab79..344ce50d00f 100644 --- a/tensorflow/python/data/kernel_tests/interleave_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_test.py @@ -35,13 +35,14 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -def _interleave(lists, cycle_length, block_length): +def _interleave(lists, cycle_length, block_length, num_parallel_calls=None): """Reference implementation of interleave used for testing. Args: lists: a list of lists to interleave cycle_length: the length of the interleave cycle block_length: the length of the interleave block + num_parallel_calls: the number of parallel calls Yields: Elements of `lists` interleaved in the order determined by `cycle_length` @@ -55,8 +56,15 @@ def _interleave(lists, cycle_length, block_length): # `open_iterators` are the iterators whose elements are currently being # interleaved. open_iterators = [] - if cycle_length == dataset_ops.AUTOTUNE: - cycle_length = multiprocessing.cpu_count() + if cycle_length is None: + # The logic here needs to match interleave C++ kernels. 
+ if num_parallel_calls is None: + cycle_length = multiprocessing.cpu_count() + elif num_parallel_calls == dataset_ops.AUTOTUNE: + cycle_length = (multiprocessing.cpu_count() + 2) // 3 + else: + cycle_length = min(num_parallel_calls, multiprocessing.cpu_count()) + for i in range(cycle_length): if all_iterators: open_iterators.append(all_iterators.pop(0)) @@ -162,7 +170,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): num_parallel_calls=[None, 1, 3, 5, 7]) + combinations.combine( input_values=[np.int64([4, 5, 6, 7])], - cycle_length=dataset_ops.AUTOTUNE, + cycle_length=None, block_length=3, num_parallel_calls=[None, 1]) + combinations.combine( input_values=[np.int64([]), np.int64([0, 0, 0])], @@ -182,7 +190,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): cycle_length, block_length, num_parallel_calls) expected_output = [ element for element in _interleave( - _repeat(input_values, count), cycle_length, block_length) + _repeat(input_values, count), cycle_length, block_length, + num_parallel_calls) ] self.assertDatasetProduces(dataset, expected_output) @@ -259,7 +268,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): block_length=2, num_parallel_calls=[1, 3, 5, 7]) + combinations.combine( input_values=[np.int64([4, 5, 6, 7])], - cycle_length=dataset_ops.AUTOTUNE, + cycle_length=None, block_length=3, num_parallel_calls=1) + combinations.combine( input_values=[np.int64([4, 0, 6])], @@ -278,7 +287,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) expected_output = [ element for element in _interleave( - _repeat(input_values, count), cycle_length, block_length) + _repeat(input_values, count), cycle_length, block_length, + num_parallel_calls) ] get_next = self.getNext(dataset) actual_output = [] diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index 
b38d008b833..dea217367dc 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -100,6 +100,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(options1.experimental_threading, threading_options.ThreadingOptions()) + @combinations.generate(test_base.eager_only_combinations()) + def testNestedDataset(self): + ds = dataset_ops.Dataset.from_tensors(0) + result = ds + + for _ in range(999): + result = result.concatenate(ds) + options = dataset_ops.Options() + options.experimental_optimization.autotune = False + result = result.with_options(options) + self.assertDatasetProduces(result, [0]*1000) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b23df3672c9..f20d9f3c1ef 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -152,14 +152,18 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Elements may be nested structures containing multiple components. For example, the element `(1, (3, "apple"))` has one tuple nested in another tuple. The components are `1`, `3`, and `"apple"`. + **Component**: The leaf in the nested structure of an element. Supported types: Elements can be nested structures of tuples, named tuples, and dictionaries. - Element components can be of any type representable by `tf.TypeSpec`, - including `tf.Tensor`, `tf.data.Dataset`, `tf.sparse.SparseTensor`, - `tf.RaggedTensor`, and `tf.TensorArray`. + Note that Python lists are *not* treated as nested structures of components. + Instead, lists are converted to tensors and treated as components. For + example, the element `(1, [1, 2, 3])` has only two components; the tensor `1` + and the tensor `[1, 2, 3]`. 
Element components can be of any type + representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`, + `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`. >>> a = 1 # Integer element >>> b = 2.0 # Float element @@ -194,6 +198,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): name="_variant_tracker") self._graph_attr = ops.get_default_graph() + # Initialize the options for this dataset and its inputs. + self._options_attr = Options() + for input_dataset in self._inputs(): + input_options = input_dataset.options() + if input_options is not None: + self._options_attr = self._options_attr.merge(input_options) + @property def _variant_tensor(self): return self._variant_tensor_attr @@ -332,12 +343,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Returns: A `tf.data.Options` object representing the dataset options. """ - options = Options() - for input_dataset in self._inputs(): - input_options = input_dataset.options() - if input_options is not None: - options = options.merge(input_options) - return options + return self._options_attr def _apply_options(self): """Apply options, such as optimization configuration, to the dataset.""" @@ -1669,8 +1675,8 @@ name=None)) def interleave(self, map_func, - cycle_length=AUTOTUNE, - block_length=1, + cycle_length=None, + block_length=None, num_parallel_calls=None, deterministic=None): """Maps `map_func` across this dataset, and interleaves the results. @@ -1739,12 +1745,13 @@ name=None)) Args: map_func: A function mapping a dataset element to a dataset. cycle_length: (Optional.) The number of input elements that will be - processed concurrently. If not specified, the value will be derived from - the number of available CPU cores. If the `num_parallel_calls` argument - is set to `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument - also identifies the maximum degree of parallelism. + processed concurrently. 
If not set, the tf.data runtime decides what it + should be based on available CPU. If `num_parallel_calls` is set to + `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument identifies + the maximum degree of parallelism. block_length: (Optional.) The number of consecutive elements to produce - from each input element before cycling to another input element. + from each input element before cycling to another input element. If not + set, defaults to 1. num_parallel_calls: (Optional.) If specified, the implementation creates a threadpool, which is used to fetch inputs from cycle elements asynchronously and in parallel. The default behavior is to fetch inputs @@ -1761,6 +1768,12 @@ name=None)) Returns: Dataset: A `Dataset`. """ + if block_length is None: + block_length = 1 + + if cycle_length is None: + cycle_length = AUTOTUNE + if num_parallel_calls is None: return InterleaveDataset(self, map_func, cycle_length, block_length) else: @@ -4197,6 +4210,7 @@ class InterleaveDataset(UnaryDataset): def __init__(self, input_dataset, map_func, cycle_length, block_length): """See `Dataset.interleave()` for details.""" + self._input_dataset = input_dataset self._map_func = StructuredFunctionWrapper( map_func, self._transformation_name(), dataset=input_dataset) @@ -4413,16 +4427,16 @@ class _OptionsDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, options): self._input_dataset = input_dataset - self._options = input_dataset.options() - if self._options: - self._options = self._options.merge(options) - else: - self._options = options variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) + if self._options_attr: + self._options_attr = self._options_attr.merge(options) + else: + self._options_attr = options + def options(self): - return self._options + return self._options_attr class _ModelDataset(UnaryUnchangedStructureDataset): diff --git 
a/tensorflow/python/data/service/server_lib_wrapper.cc b/tensorflow/python/data/service/server_lib_wrapper.cc index e273eb5b6a9..16a12eef873 100644 --- a/tensorflow/python/data/service/server_lib_wrapper.cc +++ b/tensorflow/python/data/service/server_lib_wrapper.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "Python.h" -#include "include/pybind11/chrono.h" -#include "include/pybind11/complex.h" -#include "include/pybind11/functional.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/chrono.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" #include "tensorflow/core/data/service/server_lib.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_status.h" diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index a6178fce24b..70e478dec21 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -783,7 +783,10 @@ cuda_py_test( python_version = "PY3", shard_count = 8, tags = [ + "manual", + "no_oss", "no_windows", # TODO(b/142475891): Enable this test on Windows. 
+ "notap", # TODO(b/153671240) ], xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ @@ -1463,7 +1466,7 @@ sh_test( ":debug_errors", ":debug_fibonacci", ":debug_keras", - ":debug_mnist", + ":debug_mnist_v1", ":debug_tflearn_iris", ":offline_analyzer", ], diff --git a/tensorflow/python/debug/examples/v1/examples_v1_test.sh b/tensorflow/python/debug/examples/v1/examples_v1_test.sh index afd1f2d86bc..6b52f57ba8a 100755 --- a/tensorflow/python/debug/examples/v1/examples_v1_test.sh +++ b/tensorflow/python/debug/examples/v1/examples_v1_test.sh @@ -46,14 +46,14 @@ done if [[ -z "${PYTHON_BIN_PATH}" ]]; then DEBUG_FIBONACCI_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_fibonacci" DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors" - DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist" + DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist_v1" DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris" DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras" OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer" else DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_fibonacci" DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_errors" - DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_mnist" + DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_mnist_v1" DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_tflearn_iris" DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_keras" OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer" diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py 
b/tensorflow/python/debug/lib/check_numerics_callback_test.py index 81e761386a6..ea5d70f0d08 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -31,17 +31,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras import layers -from tensorflow.python.keras import models -from tensorflow.python.keras import optimizer_v2 -from tensorflow.python.keras.applications import mobilenet_v2 +from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker_v2 +from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -from tensorflow.python.platform import test as test_lib class LimitStringLengthTest(test_util.TensorFlowTestCase): @@ -85,57 +82,6 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): y = constant_op.constant([1.0, 0.0]) self.assertAllClose((x + y) * (x - y), [3.0, 9.0]) - @test_util.run_in_graph_and_eager_modes - def testKerasModelHealthyPredictAndFitCalls(self): - """Test a simple healthy keras model runs fine under the callback.""" - check_numerics_callback.enable_check_numerics() - - model = models.Sequential() - model.add(layers.Dense( - units=100, - input_shape=(5,), - use_bias=False, - activation="relu", - kernel_initializer="ones")) - model.add(layers.BatchNormalization()) - model.add(layers.Dropout(0.5)) - model.add(layers.Dense( - units=1, - activation="linear", - kernel_initializer="ones")) - - model.compile( - loss="mse", optimizer=optimizer_v2.gradient_descent.SGD(1e-3)) - - 
batch_size = 16 - xs = np.zeros([batch_size, 5]) - ys = np.ones([batch_size, 1]) - - outputs = model.predict(xs) - self.assertEqual(outputs.shape, (batch_size, 1)) - - epochs = 100 - history = model.fit(xs, ys, epochs=epochs, verbose=0) - self.assertEqual(len(history.history["loss"]), epochs) - - @test_util.run_in_graph_and_eager_modes - def testKerasModelWithRNNHealthyPredictAndFitCalls(self): - """Test a simple healthy keras recurrent model works under the callback.""" - check_numerics_callback.enable_check_numerics() - - model = models.Sequential() - model.add(layers.LSTM(1, input_shape=(2, 4))) - model.compile(loss="mse", optimizer="rmsprop") - - xs = np.zeros([8, 2, 4], dtype=np.float32) - ys = np.zeros([8, 1], dtype=np.float32) - - model.predict(xs) - - epochs = 3 - history = model.fit(xs, ys, epochs=epochs, verbose=0) - self.assertEqual(len(history.history["loss"]), epochs) - @test_util.run_in_graph_and_eager_modes def testDatasetMapHealthyResults(self): check_numerics_callback.enable_check_numerics() @@ -153,26 +99,6 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(iterator.get_next()), np.log([1.25, 2])) self.assertAllClose(self.evaluate(iterator.get_next()), np.log([3.25, 5])) - @test_util.run_in_graph_and_eager_modes - def testMobileNetV2Fit(self): - """Test training Keras MobileNetV2 application works w/ check numerics.""" - - if test_lib.is_built_with_rocm(): - # This test passes with MIOpen Find Mode (which is the default) - # This bug is being tracked via MLOpen Issue #2379, re-enable this - # test once the fix for that issue is available in a ROCm release - self.skipTest("MIOpen bug results in test failure") - - check_numerics_callback.enable_check_numerics() - model = mobilenet_v2.MobileNetV2(alpha=0.1, weights=None) - - xs = np.zeros([2] + list(model.input_shape[1:])) - ys = np.zeros([2] + list(model.output_shape[1:])) - model.compile(optimizer="sgd", loss="categorical_crossentropy") - epochs = 1 
- history = model.fit(xs, ys, epochs=epochs, verbose=0) - self.assertEqual(len(history.history["loss"]), epochs) - class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): """Test for cases in which enable_check_numerics() catches infs or nans.""" @@ -242,54 +168,6 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): self.assertIn("# of -Inf elements: 1\n", message) self.assertTrue(re.search(r"Input tensor.*0\.", message)) - @test_util.run_in_graph_and_eager_modes - def testKerasModelUnhealthyPredictAndFitCallsWithLargeLearningRate(self): - """Test keras model training crashes with Infinity is caught by callback.""" - check_numerics_callback.enable_check_numerics() - - model = models.Sequential() - # Use weight initializers for deterministic behavior during test. - model.add(layers.Dense( - units=100, - input_shape=(5,), - activation="relu", - kernel_initializer="ones")) - model.add(layers.Dense( - units=1, - activation="linear", - kernel_initializer="ones")) - - lr = 1e3 # Intentionally huge learning rate. - model.compile(loss="mse", optimizer=optimizer_v2.gradient_descent.SGD(lr)) - - batch_size = 16 - xs = np.zeros([batch_size, 5]) - ys = np.ones([batch_size, 1]) - - outputs = model.predict(xs) - self.assertEqual(outputs.shape, (batch_size, 1)) - - epochs = 100 - message = self._assertRaisesInvalidArgumentErrorAndGetMessage( - lambda: model.fit(xs, ys, epochs=epochs, verbose=0)) - - # Check the content of the error message. - # Let's not hardcode the op name for future-proof. - self.assertTrue(re.search(r"graph op.*\".*\"", message)) - self.assertTrue(re.search(r"dtype:.*float32", message)) - self.assertTrue(re.search(r"shape:.*\(.*\)", message)) - # Check that the correct input op is printed. - self.assertTrue(re.search(r"Input tensor.*", message)) - # Check that the correct line for op creation is printed. 
- self.assertTrue(re.search(r"Stack trace of op's creation", message)) - # The stacks are different between when eager execution is enabled and - # when it's not (i.e., v1 graph). TODO(cais): Investigate if we can improve - # this. - if context.executing_eagerly(): - self.assertIn("lambda: model.fit(xs, ys,", message) - else: - self.assertIn("model.compile(", message) - @test_util.run_in_graph_and_eager_modes def testCatchFunctionOpInfFloat64(self): """Test catching infinites generated in a FuncGraph.""" @@ -389,102 +267,6 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): self.assertTrue(re.search(r"Stack trace of op's creation", message)) self.assertIn("accum.assign(accum * 2.0)", message) - @test_util.run_in_graph_and_eager_modes - def testInfInCustomKerasLayerWithTfFunctionPredictCall(self): - """Test catching Infinity in a custom layer, w/ tf.function.""" - check_numerics_callback.enable_check_numerics() - - class DivByXLayer(layers.Layer): - - @def_function.function - def call(self, x): - """The computation performed by the for-test custom layer. - - Generates Infinity by intention. - - Args: - x: Input tensor of scalar shape. - - Returns: - A scalar tensor. - """ - one_over_x = 1.0 / x - return one_over_x - - model = models.Sequential() - model.add(DivByXLayer(input_shape=[5])) - - # TODO(b/140245224): Currently the model must be compiled prior to - # predict() being called(). Or keras will fall back to V1 behavior. - # Remove this after the bug is fixed. - model.compile(loss="mse", optimizer="sgd") - - xs = np.ones([1, 5]) - # Calling the model with non-zero inputs should be fine. - self.assertAllClose(model.predict(xs), [[1.0, 1.0, 1.0, 1.0, 1.0]]) - - xs = np.zeros([1, 5]) - message = self._assertRaisesInvalidArgumentErrorAndGetMessage( - lambda: model.predict(xs)) - - # Check the content of the error message. 
- self.assertTrue(re.search(r"graph op.*\"RealDiv\"", message)) - self.assertTrue(re.search(r"dtype.*float32", message)) - self.assertTrue(re.search(r"shape: \(.*, 5\)", message)) - # # Check that the correct input op is printed. - self.assertIn("Input tensors (2):", message) - # # # Check that the correct line for op creation is printed. - self.assertTrue(re.search(r"Stack trace of op's creation", message)) - self.assertIn("one_over_x = 1.0 / x", message) - - @test_util.run_in_graph_and_eager_modes - def testInfInCustomKerasLayerWithoutTfFunctionPredictCall(self): - """Test catching Infinity in a custom layer, w/o tf.function.""" - check_numerics_callback.enable_check_numerics() - - class DivByXLayer(layers.Layer): - - # Not using the tf.function decorator here. - def call(self, x): - """The computation performed by the for-test custom layer. - - Generates Infinity by intention. - - Args: - x: Input tensor of scalar shape. - - Returns: - A scalar tensor. - """ - one_over_x = 1.0 / x - return one_over_x - - model = models.Sequential() - model.add(DivByXLayer(input_shape=[5])) - - # TODO(b/140245224): Currently the model must be compiled prior to - # predict() being called(). Or keras will fall back to V1 behavior. - # Remove this after the bug is fixed. - model.compile(loss="mse", optimizer="sgd") - - xs = np.ones([1, 5]) - # Calling the model with non-zero inputs should be fine. - self.assertAllClose(model.predict(xs), [[1.0, 1.0, 1.0, 1.0, 1.0]]) - - xs = np.zeros([1, 5]) - message = self._assertRaisesInvalidArgumentErrorAndGetMessage( - lambda: model.predict(xs)) - - # Check the content of the error message. - self.assertTrue(re.search(r"graph op.*\"RealDiv\"", message)) - self.assertTrue(re.search(r"dtype.*float32", message)) - self.assertTrue(re.search(r"shape: \(.*, 5\)", message)) - # Check that the correct input op is printed. - self.assertIn("Input tensors (2):", message) - # Check that the correct line for op creation is printed. 
- self.assertTrue(re.search(r"Stack trace of op's creation", message)) - self.assertIn("one_over_x = 1.0 / x", message) - @test_util.run_in_graph_and_eager_modes def testCatchInfinityInDatasetMapFunction(self): """Test that callback catches NaN in a tf.dataset map function.""" diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py index f861f9f278d..10fd0d61222 100644 --- a/tensorflow/python/debug/lib/debug_events_reader.py +++ b/tensorflow/python/debug/lib/debug_events_reader.py @@ -39,6 +39,13 @@ DebugEventWithOffset = collections.namedtuple( class DebugEventsReader(object): """Reader class for a tfdbg v2 DebugEvents directory.""" + # Number of digests after which a read lock is released and re-acquired during + # serial reading of digests for SourceFiles, Execution, and + # GraphExecutionTrace. This allows us to avoid releasing and re-acquiring the + # lock too often (i.e., after each digest) and to minimize performance + # penalty. + _READER_RELEASE_PER = 100 + def __init__(self, dump_root): if not file_io.is_directory(dump_root): raise ValueError("Specified dump_root is not a directory: %s" % dump_root) @@ -64,7 +71,10 @@ class DebugEventsReader(object): self._readers = dict() # A map from file path to reader. # A map from file path to current reading offset. self._reader_offsets = dict() + # Lock for reader creation. self._readers_lock = threading.Lock() + # Locks for read operation on individual readers. + self._reader_read_locks = dict() self._offsets = dict() @@ -88,19 +98,32 @@ class DebugEventsReader(object): Yields: A tuple of (offset, debug_event_proto) on each `next()` call. """ + yield_count = 0 reader = self._get_reader(file_path) - while True: - current_offset = self._reader_offsets[file_path] - try: - record, self._reader_offsets[file_path] = reader.read(current_offset) - except (errors.DataLossError, IndexError): - # We ignore partial read exceptions, because a record may be truncated. 
- # The PyRandomRecordReader throws an `IndexError` when offset goes out - # of bound. - break - yield DebugEventWithOffset( - debug_event=debug_event_pb2.DebugEvent.FromString(record), - offset=current_offset) + read_lock = self._reader_read_locks[file_path] + read_lock.acquire() + try: + while True: + current_offset = self._reader_offsets[file_path] + try: + record, self._reader_offsets[file_path] = reader.read(current_offset) + except (errors.DataLossError, IndexError): + # We ignore partial read exceptions, because a record may be + # truncated. The PyRandomRecordReader throws an `IndexError` when + # offset goes out of bound. + break + yield DebugEventWithOffset( + debug_event=debug_event_pb2.DebugEvent.FromString(record), + offset=current_offset) + yield_count += 1 + # The read lock must be periodically released to allow for concurrent + # random reads. But we do so at a number of reads, instead of after + # every single read, in order to minimize the performance penalty. + if yield_count % self._READER_RELEASE_PER == 0: + read_lock.release() + read_lock.acquire() + finally: + read_lock.release() def _get_reader(self, file_path): """Get a random-access reader for TFRecords file at file_path.""" @@ -112,6 +135,7 @@ class DebugEventsReader(object): if file_path not in self._readers: # 2nd check, with lock. 
self._readers[file_path] = tf_record.tf_record_random_reader( file_path) + self._reader_read_locks[file_path] = threading.Lock() self._reader_offsets[file_path] = 0 return self._readers[file_path] @@ -129,8 +153,9 @@ class DebugEventsReader(object): def read_source_files_event(self, offset): """Read a DebugEvent proto at given offset from the .source_files file.""" - return debug_event_pb2.DebugEvent.FromString( - self._get_reader(self._source_files_path).read(offset)[0]) + with self._reader_read_locks[self._source_files_path]: + proto_string = self._get_reader(self._source_files_path).read(offset)[0] + return debug_event_pb2.DebugEvent.FromString(proto_string) def read_graphs_event(self, offset): """Read a DebugEvent proto at a given offset from the .graphs file. @@ -151,7 +176,7 @@ class DebugEventsReader(object): def execution_iterator(self): return self._generic_iterator(self._execution_path) - def read_execution_debug_event(self, offset): + def read_execution_event(self, offset): """Read a DebugEvent proto at a given offset from the .execution file. Args: @@ -164,8 +189,9 @@ class DebugEventsReader(object): `errors.DataLossError` if offset is at a wrong location. `IndexError` if offset is out of range of the file. """ - return debug_event_pb2.DebugEvent.FromString( - self._get_reader(self._execution_path).read(offset)[0]) + with self._reader_read_locks[self._execution_path]: + proto_string = self._get_reader(self._execution_path).read(offset)[0] + return debug_event_pb2.DebugEvent.FromString(proto_string) def graph_execution_traces_iterator(self): return self._generic_iterator(self._graph_execution_traces_path) @@ -183,8 +209,10 @@ class DebugEventsReader(object): `errors.DataLossError` if offset is at a wrong location. `IndexError` if offset is out of range of the file. 
""" - return debug_event_pb2.DebugEvent.FromString( - self._get_reader(self._graph_execution_traces_path).read(offset)[0]) + with self._reader_read_locks[self._graph_execution_traces_path]: + proto_string = self._get_reader( + self._graph_execution_traces_path).read(offset)[0] + return debug_event_pb2.DebugEvent.FromString(proto_string) def close(self): with self._readers_lock: @@ -1051,8 +1079,7 @@ class DebugDataReader(object): def read_execution(self, execution_digest): """Read a detailed Execution object.""" - debug_event = self._reader.read_execution_debug_event( - execution_digest.offset) + debug_event = self._reader.read_execution_event(execution_digest.offset) return _execution_from_debug_event_proto( debug_event, execution_digest.offset) @@ -1118,7 +1145,7 @@ class DebugDataReader(object): A list of numpy arrays representing the output tensor values of the execution event. """ - debug_event = self._reader.read_execution_debug_event(execution.offset) + debug_event = self._reader.read_execution_event(execution.offset) return [_parse_tensor_value(tensor_proto) for tensor_proto in debug_event.execution.tensor_protos] @@ -1132,7 +1159,8 @@ class DebugDataReader(object): A numpy array representing the output tensor value of the intra-graph tensor execution event. 
""" - debug_event = self._reader.read_graph_execution_traces_event(trace.offset) + debug_event = self._reader.read_graph_execution_traces_event( + trace.offset) return _parse_tensor_value(debug_event.graph_execution_trace.tensor_proto) def symbolic_tensor_id(self, graph_id, op_name, output_slot): diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py index 82bb4992d0b..45b7f16b2a4 100644 --- a/tensorflow/python/debug/lib/debug_events_writer_test.py +++ b/tensorflow/python/debug/lib/debug_events_writer_test.py @@ -37,13 +37,13 @@ from tensorflow.python.platform import googletest class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): def testMultiThreadedConstructorCallWorks(self): - def InitWriter(): + def init_writer(): debug_events_writer.DebugEventsWriter(self.dump_root) num_threads = 4 threads = [] for _ in range(num_threads): - thread = threading.Thread(target=InitWriter) + thread = threading.Thread(target=init_writer) thread.start() threads.append(thread) for thread in threads: @@ -123,7 +123,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): source_file_state = {"counter": 0, "lock": threading.Lock()} - def WriteSourceFile(): + def writer_source_file(): source_file = debug_event_pb2.SourceFile() with source_file_state["lock"]: source_file.file_path = "/home/tf2user/file_%d.py" % source_file_state[ @@ -136,7 +136,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): stack_frame_state = {"counter": 0, "lock": threading.Lock()} - def WriteStackFrame(): + def write_stack_frame(): stack_frame = debug_event_pb2.StackFrameWithId() with stack_frame_state["lock"]: stack_frame.id = "stack_frame_%d" % stack_frame_state["counter"] @@ -148,7 +148,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): graph_op_state = {"counter": 0, "lock": threading.Lock()} - def 
WriteGraphOpCreation(): + def write_graph_op_creation(): graph_op_creation = debug_event_pb2.GraphOpCreation() with graph_op_state["lock"]: graph_op_creation.op_name = "Op%d" % graph_op_state["counter"] @@ -162,11 +162,11 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): threads = [] for i in range(num_threads): if i % 3 == 0: - target = WriteSourceFile + target = writer_source_file elif i % 3 == 1: - target = WriteStackFrame + target = write_stack_frame else: - target = WriteGraphOpCreation + target = write_graph_op_creation thread = threading.Thread(target=target) thread.start() threads.append(thread) @@ -294,7 +294,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): execution_state = {"counter": 0, "lock": threading.Lock()} - def WriteExecution(): + def write_execution(): execution = debug_event_pb2.Execution() with execution_state["lock"]: execution.op_type = "OpType%d" % execution_state["counter"] @@ -303,7 +303,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): graph_execution_trace_state = {"counter": 0, "lock": threading.Lock()} - def WriteGraphExecutionTrace(): + def write_graph_execution_trace(): with graph_execution_trace_state["lock"]: op_name = "Op%d" % graph_execution_trace_state["counter"] graph_op_creation = debug_event_pb2.GraphOpCreation( @@ -317,9 +317,9 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): threads = [] for i in range(circular_buffer_size * 4): if i % 2 == 0: - target = WriteExecution + target = write_execution else: - target = WriteGraphExecutionTrace + target = write_graph_execution_trace thread = threading.Thread(target=target) thread.start() threads.append(thread) @@ -341,6 +341,185 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): self.assertLen(op_names, circular_buffer_size) self.assertLen(op_names, len(set(op_names))) + def 
testConcurrentSourceFileRandomReads(self): + writer = debug_events_writer.DebugEventsWriter(self.dump_root) + + for i in range(100): + source_file = debug_event_pb2.SourceFile( + host_name="localhost", file_path="/tmp/file_%d.py" % i) + source_file.lines.append("# File %d" % i) + writer.WriteSourceFile(source_file) + writer.FlushNonExecutionFiles() + + reader = debug_events_reader.DebugDataReader(self.dump_root) + reader.update() + lines = [None] * 100 + def read_job_1(): + # Read in the reverse order to enhance randomness of the read access. + for i in range(49, -1, -1): + lines[i] = reader.source_lines("localhost", "/tmp/file_%d.py" % i) + def read_job_2(): + for i in range(99, 49, -1): + lines[i] = reader.source_lines("localhost", "/tmp/file_%d.py" % i) + thread_1 = threading.Thread(target=read_job_1) + thread_2 = threading.Thread(target=read_job_2) + thread_1.start() + thread_2.start() + thread_1.join() + thread_2.join() + for i in range(100): + self.assertEqual(lines[i], ["# File %d" % i]) + + def testConcurrentExecutionUpdateAndRandomRead(self): + circular_buffer_size = -1 + writer = debug_events_writer.DebugEventsWriter(self.dump_root, + circular_buffer_size) + + writer_state = {"counter": 0, "done": False} + + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + def write_and_update_job(): + while True: + if writer_state["done"]: + break + execution = debug_event_pb2.Execution() + execution.op_type = "OpType%d" % writer_state["counter"] + writer_state["counter"] += 1 + writer.WriteExecution(execution) + writer.FlushExecutionFiles() + reader.update() + # On the sub-thread, keep writing and reading new Execution protos. + write_and_update_thread = threading.Thread(target=write_and_update_job) + write_and_update_thread.start() + # On the main thread, do concurrent random read. 
+ while True: + exec_digests = reader.executions(digest=True) + if exec_digests: + exec_0 = reader.read_execution(exec_digests[0]) + self.assertEqual(exec_0.op_type, "OpType0") + writer_state["done"] = True + break + else: + time.sleep(0.1) + continue + write_and_update_thread.join() + + def testConcurrentExecutionRandomReads(self): + circular_buffer_size = -1 + writer = debug_events_writer.DebugEventsWriter(self.dump_root, + circular_buffer_size) + + for i in range(100): + execution = debug_event_pb2.Execution() + execution.op_type = "OpType%d" % i + writer.WriteExecution(execution) + writer.FlushNonExecutionFiles() + writer.FlushExecutionFiles() + + reader = debug_events_reader.DebugDataReader(self.dump_root) + reader.update() + executions = [None] * 100 + def read_job_1(): + execution_digests = reader.executions(digest=True) + # Read in the reverse order to enhance randomness of the read access. + for i in range(49, -1, -1): + execution = reader.read_execution(execution_digests[i]) + executions[i] = execution + def read_job_2(): + execution_digests = reader.executions(digest=True) + for i in range(99, 49, -1): + execution = reader.read_execution(execution_digests[i]) + executions[i] = execution + thread_1 = threading.Thread(target=read_job_1) + thread_2 = threading.Thread(target=read_job_2) + thread_1.start() + thread_2.start() + thread_1.join() + thread_2.join() + for i in range(100): + self.assertEqual(executions[i].op_type, "OpType%d" % i) + + def testConcurrentGraphExecutionTraceUpdateAndRandomRead(self): + circular_buffer_size = -1 + writer = debug_events_writer.DebugEventsWriter(self.dump_root, + circular_buffer_size) + debugged_graph = debug_event_pb2.DebuggedGraph(graph_id="graph1", + graph_name="graph1") + writer.WriteDebuggedGraph(debugged_graph) + + writer_state = {"counter": 0, "done": False} + + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + def write_and_update_job(): + while True: + if writer_state["done"]: + break + op_name 
= "Op%d" % writer_state["counter"] + graph_op_creation = debug_event_pb2.GraphOpCreation( + op_type="FooOp", op_name=op_name, graph_id="graph1") + writer.WriteGraphOpCreation(graph_op_creation) + trace = debug_event_pb2.GraphExecutionTrace( + op_name=op_name, tfdbg_context_id="graph1") + writer.WriteGraphExecutionTrace(trace) + writer_state["counter"] += 1 + writer.FlushNonExecutionFiles() + writer.FlushExecutionFiles() + reader.update() + # On the sub-thread, keep writing and reading new GraphExecutionTraces. + write_and_update_thread = threading.Thread(target=write_and_update_job) + write_and_update_thread.start() + # On the main thread, do concurrent random read. + while True: + digests = reader.graph_execution_traces(digest=True) + if digests: + trace_0 = reader.read_graph_execution_trace(digests[0]) + self.assertEqual(trace_0.op_name, "Op0") + writer_state["done"] = True + break + else: + time.sleep(0.1) + continue + write_and_update_thread.join() + + def testConcurrentGraphExecutionTraceRandomReads(self): + circular_buffer_size = -1 + writer = debug_events_writer.DebugEventsWriter(self.dump_root, + circular_buffer_size) + debugged_graph = debug_event_pb2.DebuggedGraph(graph_id="graph1", + graph_name="graph1") + writer.WriteDebuggedGraph(debugged_graph) + + for i in range(100): + op_name = "Op%d" % i + graph_op_creation = debug_event_pb2.GraphOpCreation( + op_type="FooOp", op_name=op_name, graph_id="graph1") + writer.WriteGraphOpCreation(graph_op_creation) + trace = debug_event_pb2.GraphExecutionTrace( + op_name=op_name, tfdbg_context_id="graph1") + writer.WriteGraphExecutionTrace(trace) + writer.FlushNonExecutionFiles() + writer.FlushExecutionFiles() + + reader = debug_events_reader.DebugDataReader(self.dump_root) + reader.update() + traces = [None] * 100 + def read_job_1(): + digests = reader.graph_execution_traces(digest=True) + for i in range(49, -1, -1): + traces[i] = reader.read_graph_execution_trace(digests[i]) + def read_job_2(): + digests = 
reader.graph_execution_traces(digest=True) + for i in range(99, 49, -1): + traces[i] = reader.read_graph_execution_trace(digests[i]) + thread_1 = threading.Thread(target=read_job_1) + thread_2 = threading.Thread(target=read_job_2) + thread_1.start() + thread_2.start() + thread_1.join() + thread_2.join() + for i in range(100): + self.assertEqual(traces[i].op_name, "Op%d" % i) + class DataObjectsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/distribute/cluster_resolver/README.slurm b/tensorflow/python/distribute/cluster_resolver/README.slurm deleted file mode 100644 index 3a7675f250d..00000000000 --- a/tensorflow/python/distribute/cluster_resolver/README.slurm +++ /dev/null @@ -1,50 +0,0 @@ -# Slurm Cluster Resolver - -The Slurm Cluster Resolver resolves cluster specification for distribution TensorFlow work launched on HPC system running on Slurm. This implementation is able to handle homogeneous task allocation on computing nodes with default task distribution plane. The resolution is done by determining job configuration through a number of Slurm output variables and user input. The resolver requires the specification of total number of tasks launched, process ID/rank of the running process, number of tasks launched per node, number of GPUs present on each node and the number of GPUs to allocate for each task. - -The process ID/rank is extracted from environment variable ```SLURM_PROCID``` and the total number of tasks launched is extract from ```SLURM_NTASKS```. The number of tasks per node is extracted from ```SLURM_NTASKS_PER_NODE```, unless a value is specified by user. The number of GPUs present on each node and number of GPUs for each task have to be specified by the user. A base port can be specified by user and in case there are more than one task launched per node the port number will be incremented for each additional tasks on that node. 
The hostnames are resolved by running command ```scontrol show hostname``` through a subprocess and a list of hostnames will be returned. The distribution of rank/process ID by default follows that order. By default allocated GPUs will be automatically exposed to processes according to specification by setting ```CUDA_VISIBLE_DEVICE```. - -## Example -- Slurm allocation in shell ```salloc --nodes=2 -t 01:30:00 -A --ntasks-per-node=2 --gres=gpu:k80:2 --exclusive``` -- Creating cluster in Python -``` -cluster_resolver = tf.contrib.cluster_resolver.SlurmClusterResolver( - {'ps': 1, 'worker': 3}, - port_base=8888, - tasks_per_node=2, - gpus_per_node=2, - gpus_per_task=1, - auto_set_gpu=True) - -cluster = cluster_resolver.cluster_spec() -job_name, task_index = cluster_resolver.get_task_info() -``` -The above example resolves a cluster specification for a Slurm job allocation with two computing nodes each having two GPUs and two tasks will be launched on each node. The jobs are specified in form of a dictionary where the key is a string representing job name and value is an integer that specifies the number of tasks in that job. ```cluster_resolver.cluster_spec()``` will return a cluster specificaiton object and the cluster specification will have the following specification as protobuf. - -``` -job { - name: "ps" - tasks { - value: "t02n13:8888" - } -} -job { - name: "worker" - tasks { - value: "t02n13:8889" - } - tasks { - key: 1 - value: "t02n41:8888" - } - tasks { - key: 2 - value: "t02n41:8889" - } -} -``` - -After calling ```cluster_resolver.cluster_spec()``` internal data structions of the resolver will be populated. By looking at the process ID/rank and comparing with cluster specification the task can 'realize' which task it belongs to. This can be retrieved by calling ```cluster_resolver.get_task_info()``` and a string specifying job name and an integer specifying the task index will be returned. - -GPUs will be automatically allocated to the processes. 
For example in the above example ``` -t02n41:8888``` will see GPU 0 and ```t02n41:8889``` will see GPU 1. diff --git a/tensorflow/python/distribute/cluster_resolver/README_Slurm.md b/tensorflow/python/distribute/cluster_resolver/README_Slurm.md new file mode 100644 index 00000000000..af89a1ce124 --- /dev/null +++ b/tensorflow/python/distribute/cluster_resolver/README_Slurm.md @@ -0,0 +1,97 @@ +# Slurm Cluster Resolver + +The Slurm Cluster Resolver resolves cluster specification for distributing +TensorFlow work launched on HPC systems running on Slurm. This implementation is +able to handle homogeneous and heterogeneous tasks as long as the number of GPUs +per node and task are the same. This means on nodes with 4 GPUs each it will be +possible to allocate 4 processes on node A and only 2 on node B. The resolution +is done by determining job configuration through a number of Slurm variables and +can be overwritten by user input. By default everything is determined from the +Slurm environment, hence for most use cases no manual setting of parameters is +required. + +## How it works + +`SlurmClusterResolver` reads the environment variables that are set inside a job +step launched by Slurm. This means it will only work correctly for applications +launched via `srun`. + +The process ID/rank is extracted from environment variable `SLURM_PROCID` and +the total number of tasks launched is extracted from `SLURM_STEP_NUM_TASKS`. The +hostnames are resolved by inspecting `SLURM_STEP_NODELIST`. The number of tasks +per node is extracted from `SLURM_STEP_TASKS_PER_NODE`, unless a value is +specified by user. By using this variable heterogeneous task distributions are +possible. The user can set `tasks_per_node` to a single integer for homogeneous +tasks or a dictionary mapping node names to number of tasks for heterogeneous +distributions. However setting this is **NOT** recommended as there is a chance +it makes `SLURM_PROCID` be wrong.
+ +A base port can be specified by user and in case there are more than one task +launched per node the port number will be incremented for each additional tasks +on that node. However a reasonable default is used. + +The number of GPUs present on each node and number of GPUs for each tasks are +automatically detected. This is done by checking for `CUDA_VISIBLE_DEVICES` +first (which is set by Slurm to a list of GPUs for the current node) and has a +fallback to using `nvidia-smi`. If this doesn't work or non-NVIDIA GPUs are used +those 2 values have to be specified by the user. By default allocated GPUs will +be automatically exposed to processes according to specification by setting +`CUDA_VISIBLE_DEVICES`. + +## Basic example + +- Slurm allocation in shell `salloc --nodes=2 -t 01:30:00 --ntasks-per-node=2 + --gres=gpu:k80:4 --exclusive` +- Run the example `srun python tf_example.py` +- Creating cluster in Python `import tensorflow as tf cluster_resolver = + tf.distribute.cluster_resolver.SlurmClusterResolver() strategy = + tf.distribute.experimental.MultiWorkerMirroredStrategy(cluster_resolver=cluster_resolver) + with strategy.scope(): # Load and compile model and data` + +The above example will allocate 4 jobs on 2 nodes with each node having 2 jobs +and 4 GPUs. `cluster_resolver.cluster_spec()` will return a cluster +specification object in protobuf format with the following value (host names may +vary): `job { name: "worker" tasks { key: 0 value: "t02n13:8888" } tasks { key: +1 value: "t02n13:8889" } tasks { key: 2 value: "t02n41:8888" } tasks { key: 3 +value: "t02n41:8889" } }` + +The `job_name` will be `worker` for all nodes and `task_index` will be `0` to +`3`. Also GPUs will be allocated automatically, so the first job on each node +will see GPU 0 and 1, and the second GPU 2 and 3. 
+ +## Advanced example + +- Assuming the same job parameters (`salloc` & `srun`) as above +- Creating cluster in Python ``` cluster_resolver = + tf.distribute.cluster_resolver.SlurmClusterResolver( {'ps': 1, 'worker': 3}, + port_base=1337, tasks_per_node=2, gpus_per_node=2, gpus_per_task=1, + auto_set_gpu=False) + +cluster = cluster_resolver.cluster_spec() job_name, task_index = +cluster_resolver.get_task_info() ``` + +In this case 1 parameter server job and 3 worker jobs are used. The resulting +protobuf specification will look similar to this: `job { name: "ps" tasks { key: +0 value: "t02n13:1337" } } job { name: "worker" tasks { key: 0 value: +"t02n13:1338" } tasks { key: 1 value: "t02n41:1337" } tasks { key: 2 value: +"t02n41:1338" } }` + +The value of `job_name` will be `ps` for `t02n13:1337` and `worker` for all +others. There will be no GPU allocation done by the cluster resolver, so this +has to be done manually, which is useful if e.g. GPU 0 should go to the first +process and GPU 3 to the second process on each node. Also note that only 1 GPU +will be used per task. + +## Extension points + +The class `SlurmClusterResolver` provides some methods that are meant to be +overridden by deriving classes: + +- `_resolve_own_rank` +- `_resolve_num_tasks` +- `_resolve_hostlist` +- `_resolve_task_configuration` + + Those can be used to implement a cluster resolver that gets information from + a different source, e.g. via MPI, a file or other environment variables. See + the documentation of these methods on what to return. 
diff --git a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py index 62759dd0853..3b9f8a259dd 100644 --- a/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/slurm_cluster_resolver.py @@ -201,7 +201,7 @@ class SlurmClusterResolver(ClusterResolver): - SLURM_PROCID - (opt) SLURM_STEP_NUM_TASKS - (opt) SLURM_STEP_NODELIST - - (opt) SLURM_TASKS_PER_NODE + - (opt) SLURM_STEP_TASKS_PER_NODE Args: jobs: Dictionary with job names as key and number of tasks in the job as diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index e9a45f0cc10..ffa03ee5329 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -181,14 +181,13 @@ class NamedTPUCombination(combinations_lib.TestCombination): if not number_of_required_tpus and TPUCombination.TPU_TEST: return (False, "Test that doesn't require TPUs.") - elif number_of_required_tpus and not TPUCombination.TPU_TEST: + if number_of_required_tpus and not TPUCombination.TPU_TEST: return (False, "Test requires a TPU, but it's not available.") - elif use_cloud_tpu and not tpu: + if use_cloud_tpu and not tpu: return (False, "Test requires a Cloud TPU, but none specified.") - elif not use_cloud_tpu and tpu: + if not use_cloud_tpu and tpu: return (False, "Test requires local TPU, but Cloud TPU specified.") - else: - return (True, None) + return (True, None) def parameter_modifiers(self): return [ diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py index 2e9a8db5bc8..3c748bd7364 100644 --- a/tensorflow/python/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/distribute/custom_training_loop_models_test.py @@ -214,6 +214,42 @@ class 
KerasModelsTest(test.TestCase, parameterized.TestCase): train_step(input_iterator) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, + mode=["eager"] + )) + def test_batch_norm_with_dynamic_batch(self, distribution): + inputs = np.zeros((10, 3, 3, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat() + dataset = dataset.batch(10, drop_remainder=False) + input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) + + with distribution.scope(): + x = keras.layers.Input(shape=(3, 3, 3), name="input") + y = keras.layers.BatchNormalization(fused=True, name="bn")(x) + y = keras.layers.Flatten()(y) + y = keras.layers.Dense(4, name="dense")(y) + model = keras.Model(x, y) + optimizer = keras.optimizer_v2.rmsprop.RMSprop() + + @def_function.function + def train_step(iterator): + def step_fn(inputs): + images, targets = inputs + with backprop.GradientTape() as tape: + outputs = model(images, training=True) + loss = math_ops.reduce_sum(outputs - targets) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + return loss + + distribution.run(step_fn, args=(next(iterator),)) + + train_step(input_iterator) + @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies, diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 6d4d9b38063..e091f0da5e2 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -36,6 +36,13 @@ from tensorflow.python.distribute import multi_process_lib from tensorflow.python.eager import context from tensorflow.python.platform import test +# pylint: disable=g-import-not-at-top +try: + # `faulthandler` is not available in py2. 
+ import faulthandler +except ImportError: + faulthandler = None + # _ProcessStatusInfo contains process status information. When is_successful # attribute is True, the subprocess has ended successfully, or if False, the # exception stack trace info is stored in exc_info to pass on to parent process @@ -366,7 +373,7 @@ class MultiProcessRunner(object): break return list_to_return - def join(self, timeout=None): + def join(self, timeout=250): """Joins all the processes with timeout. Args: @@ -397,6 +404,9 @@ class MultiProcessRunner(object): if self._all_forced_terminated: break if time.time() - start_time > timeout: + # Send SIGTERM signal to subprocesses to dump their current + # stack trace. + self.terminate_all(sig=signal.SIGTERM) # If none of those did, report timeout to user. raise RuntimeError('One or more subprocesses timed out. ' 'Number of outstanding subprocesses ' @@ -435,8 +445,11 @@ class MultiProcessRunner(object): _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format( task_type, task_id)) - def terminate_all(self): + def terminate_all(self, sig=None): """Terminates all subprocesses.""" + # Use SIGKILL as default. In systems where that's unavailable such as + # windows, use SIGTERM. + sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) subprocess_infos = [] while True: @@ -449,7 +462,7 @@ class MultiProcessRunner(object): for subprocess_info in subprocess_infos: logging.info('Parent process is now killing PID: %d', subprocess_info.pid) try: - os.kill(subprocess_info.pid, signal.SIGKILL) + os.kill(subprocess_info.pid, sig) except ProcessLookupError: # TODO(rchao): Remove subprocess info from the queue once a subprocess # is terminated. 
@@ -510,11 +523,14 @@ class _Subprocess(object): *arg, **kwargs): """The wrapper function that actually gets run in child process(es).""" + if faulthandler is not None: + faulthandler.enable() + faulthandler.register(signal.SIGTERM, chain=True) + pid = os.getpid() logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, task_type, task_id) _resource(SUBPROCESS_INFO_QUEUE).put(_SubprocessInfo(pid=pid)) - # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and # logging.*() write directly to `pipe_w`. Unfortunately since we cannot # prepend task_type and task_id information to the streamed logs we will @@ -605,7 +621,7 @@ def run(proc_func, grpc_fail_fast=None, stream_stdout=True, list_stdout=False, - timeout=None, + timeout=250, args=None, kwargs=None): # pylint: disable=g-doc-args """Runs functions in local child processes. diff --git a/tensorflow/python/distribute/zero_batch_test.py b/tensorflow/python/distribute/zero_batch_test.py index e590d815459..b41611a91e0 100644 --- a/tensorflow/python/distribute/zero_batch_test.py +++ b/tensorflow/python/distribute/zero_batch_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import backprop @@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32), test_step().numpy()) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.one_device_strategy, + ], + mode=["eager"], + fused=[True, False])) + def testBNWithDynamicBatchInputEager(self, distribution, fused): + distribution.extended.experimental_enable_get_next_as_optional = True + with distribution.scope(): + # Explicitly create dataset with 
drop_remainder=False. + # This would make batch size unknown. + inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100 + targets = np.random.random((11, 4, 4, 3)).astype(np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch( + 10, drop_remainder=False).repeat() + dataset_iterator = iter( + distribution.experimental_distribute_dataset(dataset)) + + bn = normalization.BatchNormalization( + axis=-1, epsilon=1e-3, momentum=0.9, fused=fused) + optimizer = gradient_descent.GradientDescentOptimizer(0.01) + + @def_function.function + def train_step(iterator): + + def step_fn(inputs): + features, targets = inputs + with backprop.GradientTape() as tape: + outputs = bn(features, training=True) + loss = losses.mean_squared_error(targets, outputs) + + grads = tape.gradient(loss, bn.variables) + optimizer.apply_gradients(zip(grads, bn.variables)) + return loss + + return distribution.run(step_fn, args=(next(iterator),)) + + for _ in range(100): + train_step(dataset_iterator).numpy() + + # Verify that the statistics and weights are updated. 
+ self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy()) + self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy()) + self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy()) + self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9c229974b05..e5d7a9cfe50 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -1,5 +1,8 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") load( "//tensorflow/tools/test:performance.bzl", @@ -163,6 +166,47 @@ py_library( ], ) +tf_python_pybind_extension( + name = "custom_device_testutil", + testonly = True, + srcs = ["custom_device_testutil.cc"], + module_name = "custom_device_testutil", + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:custom_device_testutil", + "//tensorflow/python:cpp_python_util", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "//tensorflow/python:safe_ptr", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + +py_test( + name = "custom_device_test", + size = "small", + srcs = ["custom_device_test.py"], + python_version = "PY3", + # Note that this currently only works with --config=monolithic, since it + # requires the C API which runs static initializers again. + # + # TODO(allenl): Figure out a way to allow extensions to register custom + # devices which works with dynamic linking. 
+ tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":context", + ":custom_device_testutil", + ":test", + ], +) + cuda_py_test( name = "context_test", size = "small", @@ -575,13 +619,11 @@ cuda_py_test( ":context", ":forwardprop", ":function", - ":profiler", ":remote", ":test", "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tfe", "//tensorflow/python:random_ops", - "//tensorflow/python/keras", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 31e821068d3..1ecbcdfa4f7 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -38,16 +38,13 @@ import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python import keras from tensorflow.python import pywrap_tfe -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import context from tensorflow.python.eager import core from tensorflow.python.eager import def_function from tensorflow.python.eager import forwardprop from tensorflow.python.eager import function -from tensorflow.python.eager import profiler from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -62,7 +59,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import gradient_descent + CPU = "/device:CPU:0" GPU = "/device:GPU:0" @@ -90,60 +87,6 @@ def c_tfe_py_fastpath_execute(a, six.raise_from(core._status_to_exception(e.code, message), None) -class SubclassedKerasModel(keras.Model): - - def __init__(self, initializer="ones"): - super(SubclassedKerasModel, self).__init__() - self.layer_a = 
keras.layers.Dense( - 64, kernel_initializer=initializer, bias_initializer="zeros") - self.layer_b = keras.layers.Dense( - 128, kernel_initializer=initializer, bias_initializer="zeros") - self.layer_c = keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros") - self.layer_d = keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros") - self.layer_e = keras.layers.Dense( - 10, kernel_initializer=initializer, bias_initializer="zeros") - - def call(self, x): - x = self.layer_a(x) - x = self.layer_b(x) - x = self.layer_c(x) - x = self.layer_d(x) - return self.layer_e(x) - - -def make_keras_model(initializer="ones"): - model_input = keras.Input(shape=(10,)) - x = keras.layers.Dense( - 64, kernel_initializer=initializer, bias_initializer="zeros")(model_input) - x = keras.layers.Dense( - 128, kernel_initializer=initializer, bias_initializer="zeros")(x) - x = keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros")(x) - x = keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros")(x) - x = keras.layers.Dense( - 10, kernel_initializer=initializer, bias_initializer="zeros")(x) - return keras.Model(inputs=model_input, outputs=x) - - -def make_sequential_keras_model(initializer="ones"): - model = keras.models.Sequential() - model.add(keras.layers.Dense( - 64, kernel_initializer=initializer, bias_initializer="zeros", - input_shape=(10,))) - model.add(keras.layers.Dense( - 128, kernel_initializer=initializer, bias_initializer="zeros")) - model.add(keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros")) - model.add(keras.layers.Dense( - 256, kernel_initializer=initializer, bias_initializer="zeros")) - model.add(keras.layers.Dense( - 10, kernel_initializer=initializer, bias_initializer="zeros")) - return model - - def run_benchmark(func, num_iters, execution_mode=None): ctx = context.context() with context.execution_mode(execution_mode): @@ 
-164,12 +107,15 @@ def run_benchmark(func, num_iters, execution_mode=None): class MicroBenchmarks(test.Benchmark): def __init__(self): - # used for multiply benchmarks - self._m_2 = random_ops.random_uniform([2]) + # TODO(b/153054118): Add tf.RandomUniform + if not context.is_tfrt_enabled(): + # used for multiply benchmarks + self._m_2 = random_ops.random_uniform([2]) + + # used for matmul benchmarks + self._m_2_by_2 = random_ops.random_uniform((2, 2)) + self._m_100_by_784 = random_ops.random_uniform((100, 784)) - # used for matmul benchmarks - self._m_2_by_2 = random_ops.random_uniform((2, 2)) - self._m_100_by_784 = random_ops.random_uniform((100, 784)) self._num_iters_2_by_2 = 30000 self._num_iters_100_by_784 = 30000 @@ -1121,180 +1067,6 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_read_variable_with_tape( m, num_iters=self._num_iters_2_by_2) - def benchmark_keras_model_subclassed(self): - model = SubclassedKerasModel() - data = random_ops.random_uniform((10, 10)) - - func = lambda: model(data) - # First call is more expensive (creates variables etc.), discount that. - func() - - # The whole point of this test is to contrast subclassing with - # the functional style of keras model building, so validate that - # the models are equivalent. 
- assert np.equal(func(), make_keras_model()(data)).all() - - self._run(func, 30000) - - def benchmark_keras_model_functional(self): - model = make_keras_model() - data = random_ops.random_uniform((10, 10)) - func = lambda: model(data) - # Symmetry with benchmark_keras_model_subclassed - func() - assert np.equal(func(), SubclassedKerasModel()(data)).all() - self._run(func, 30000) - - def benchmark_keras_model_sequential(self): - model = make_sequential_keras_model() - data = random_ops.random_uniform((10, 10)) - func = lambda: model(data) - # Symmetry with benchmark_keras_model_functional - func() - assert np.equal(func(), make_keras_model()(data)).all() - self._run(func, 30000) - - def _benchmark_keras_model_fit(self, model, run_eagerly=False): - data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) - labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1) - dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat() - model.compile( - gradient_descent.GradientDescentOptimizer(learning_rate=0.001), - loss="mse", run_eagerly=run_eagerly) - func = lambda: model.fit(dataset, epochs=1, steps_per_epoch=1000, verbose=0) - # First call is more expensive (creates variables etc.), discount that. - model.fit(dataset, epochs=1, steps_per_epoch=1, verbose=0) - - self._run(func, 1) - - def _benchmark_keras_model_evaluate(self, model, run_eagerly=False): - data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) - labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1) - dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat() - model.compile( - gradient_descent.GradientDescentOptimizer(learning_rate=0.001), - loss="mse", run_eagerly=run_eagerly) - func = lambda: model.evaluate(dataset, steps=1000, verbose=0) - # First call is more expensive (creates variables etc.), discount that. 
- model.evaluate(dataset, steps=1, verbose=0) - - self._run(func, 1) - - def _benchmark_keras_model_predict(self, model, run_eagerly=False): - data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) - dataset = dataset_ops.Dataset.from_tensors(data).repeat() - model.compile( - gradient_descent.GradientDescentOptimizer(learning_rate=0.001), - loss="mse", run_eagerly=run_eagerly) - func = lambda: model.predict(dataset, steps=1000, verbose=0) - # First call is more expensive (creates variables etc.), discount that. - model.predict(dataset, steps=1, verbose=0) - - self._run(func, 1) - - def benchmark_keras_model_subclassed_fit(self): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_subclassed_fit_graph_mode(self): - with context.graph_mode(): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_subclassed_fit_run_model_eagerly(self): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model, run_eagerly=True) - - def benchmark_keras_model_functional_fit(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_functional_fit_graph_mode(self): - with context.graph_mode(): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_functional_fit_graph_mode_with_profiler(self): - profiler.start() - with context.graph_mode(): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - result = profiler.stop() - assert result is not None - - def benchmark_keras_model_functional_fit_run_model_eagerly(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model, run_eagerly=True) - - def 
benchmark_keras_model_functional_fit_run_model_eagerly_with_profiler( - self): - profiler.start() - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model, run_eagerly=True) - result = profiler.stop() - assert result is not None - - def benchmark_keras_model_sequential_fit(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_sequential_fit_graph_mode(self): - with context.graph_mode(): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model) - - def benchmark_keras_model_sequential_fit_run_model_eagerly(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_fit(model, run_eagerly=True) - - def benchmark_keras_model_subclassed_evaluate(self): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model) - - def benchmark_keras_model_subclassed_evaluate_run_model_eagerly(self): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model, run_eagerly=True) - - def benchmark_keras_model_functional_evaluate(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model) - - def benchmark_keras_model_functional_evaluate_run_model_eagerly(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model, run_eagerly=True) - - def benchmark_keras_model_sequential_evaluate(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model) - - def benchmark_keras_model_sequential_evaluate_run_model_eagerly(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_evaluate(model, run_eagerly=True) - - def benchmark_keras_model_subclassed_predict(self): - model = 
SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model) - - def benchmark_keras_model_subclassed_predict_run_model_eagerly(self): - model = SubclassedKerasModel(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model, run_eagerly=True) - - def benchmark_keras_model_functional_predict(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model) - - def benchmark_keras_model_functional_predict_run_model_eagerly(self): - model = make_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model, run_eagerly=True) - - def benchmark_keras_model_sequential_predict(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model) - - def benchmark_keras_model_sequential_predict_run_model_eagerly(self): - model = make_sequential_keras_model(initializer="glorot_uniform") - self._benchmark_keras_model_predict(model, run_eagerly=True) - def benchmarkScan(self): elems = math_ops.range(1600) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 05f8a4ea066..eb928614817 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1116,6 +1116,13 @@ class Context(object): return function_def + def register_custom_device(self, device_capsule, device_name, + device_info_capsule): + """Calls TFE_RegisterCustomDevice. See the non-member function.""" + self.ensure_initialized() + pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, + device_name, device_info_capsule) + def remove_function(self, name): """Remove a function from the context. @@ -2287,6 +2294,32 @@ def get_function_def(name): return context().get_function_def(name) +def register_custom_device(device_capsule, device_name, device_info_capsule): + """Calls TFE_RegisterCustomDevice to register a custom device with Python. 
+ + Enables using C extensions specifying a custom device from Python. See the + experimental eager C API in tensorflow/c/eager/c_api_experimental.h for + details. + + Note that custom devices are not currently supported inside `tf.function`s. + + Args: + device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice' + containing a pointer to a TFE_CustomDevice struct. The capsule retains + ownership of the memory. + device_name: A string indicating the name to register the custom device + under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may + subsequently be passed to `with tf.device(...):`. + device_info_capsule: A PyCapsule with the name set to + 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific + struct with the initial state of the custom device (the void* device_info + argument to TFE_RegisterCustomDevice). This method takes ownership of the + memory and clears the capsule destructor. + """ + context().register_custom_device(device_capsule, device_name, + device_info_capsule) + + # Not every user creates a Context via context.context() # (for example, enable_eager_execution in python/framework/ops.py), # but they do all import this file. Note that IS_IN_GRAPH_MODE and diff --git a/tensorflow/python/eager/custom_device_test.py b/tensorflow/python/eager/custom_device_test.py new file mode 100644 index 00000000000..9a24383a13c --- /dev/null +++ b/tensorflow/python/eager/custom_device_test.py @@ -0,0 +1,50 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.eager import custom_device_testutil +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class CustomDeviceTest(test.TestCase): + + def testRegisterCustomDevice(self): + device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0' + device, device_info, arrived_flag, executed_flag = ( + custom_device_testutil.GetLoggingDeviceCapsules(device_name)) + context.register_custom_device(device, device_name, device_info) + self.assertFalse(custom_device_testutil.FlagValue(arrived_flag)) + self.assertFalse(custom_device_testutil.FlagValue(executed_flag)) + with ops.device(device_name): + x = constant_op.constant(1.) + y = x * constant_op.constant(2.) + self.assertTrue(custom_device_testutil.FlagValue(executed_flag)) + # There was no copy onto the device. Actually I'm not sure how to trigger + # that from Python. + self.assertFalse(custom_device_testutil.FlagValue(arrived_flag)) + with self.assertRaisesRegexp(errors.InternalError, 'Trying to copy'): + y.numpy() + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/eager/custom_device_testutil.cc b/tensorflow/python/eager/custom_device_testutil.cc new file mode 100644 index 00000000000..214c1811c13 --- /dev/null +++ b/tensorflow/python/eager/custom_device_testutil.cc @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/custom_device_testutil.h" + +#include "Python.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" +#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/util/util.h" + +namespace py = pybind11; + +void CallDelete_Flag(PyObject* capsule) { + delete reinterpret_cast(PyCapsule_GetPointer(capsule, "flag")); +} + +void CallDelete_Device(PyObject* capsule) { + delete reinterpret_cast( + PyCapsule_GetPointer(capsule, "TFE_CustomDevice")); +} + +void CallDelete_DeviceInfo(PyObject* capsule) { + PyErr_SetString(PyExc_AssertionError, + "Capsule should be consumed by TFE_Py_RegisterCustomDevice"); +} + +PYBIND11_MODULE(custom_device_testutil, m) { + m.def("GetLoggingDeviceCapsules", [](const char* name) { + bool* arrived_flag = new bool; + bool* executed_flag = new bool; + *arrived_flag = false; + *executed_flag = false; + tensorflow::Safe_PyObjectPtr arrived_capsule( + 
PyCapsule_New(arrived_flag, "flag", &CallDelete_Flag)); + tensorflow::Safe_PyObjectPtr executed_capsule( + PyCapsule_New(executed_flag, "flag", &CallDelete_Flag)); + TFE_CustomDevice* device; + void* device_info; + AllocateLoggingDevice(name, arrived_flag, executed_flag, &device, + &device_info); + tensorflow::Safe_PyObjectPtr device_capsule( + PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device)); + tensorflow::Safe_PyObjectPtr device_info_capsule(PyCapsule_New( + device_info, "TFE_CustomDevice_DeviceInfo", &CallDelete_DeviceInfo)); + return tensorflow::pyo_or_throw( + PyTuple_Pack(4, device_capsule.get(), device_info_capsule.get(), + arrived_capsule.get(), executed_capsule.get())); + }); + m.def("FlagValue", [](py::capsule flag_capsule) { + bool* flag = reinterpret_cast( + PyCapsule_GetPointer(flag_capsule.ptr(), "flag")); + if (PyErr_Occurred()) throw py::error_already_set(); + return *flag; + }); +} diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 3a311f2b2c5..f17033d126c 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -254,6 +255,21 @@ class DefFunctionTest(test.TestCase): z() + def testArgMinMax(self): + + @def_function.function(experimental_compile=True) + def argmax(x): + return math_ops.argmax(x) + + @def_function.function(experimental_compile=True) + def argmin(x): + return math_ops.argmin(x) + + self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32))) + self.assertAllClose(0, argmax(array_ops.ones([10]))) + 
self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32))) + self.assertAllClose(0, argmin(array_ops.ones([10]))) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/memory_tests/BUILD b/tensorflow/python/eager/memory_tests/BUILD index c9694c64694..419de91b42a 100644 --- a/tensorflow/python/eager/memory_tests/BUILD +++ b/tensorflow/python/eager/memory_tests/BUILD @@ -34,7 +34,6 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:test", - "//tensorflow/python/keras", "@six_archive//:six", ], ) diff --git a/tensorflow/python/eager/memory_tests/memory_test.py b/tensorflow/python/eager/memory_tests/memory_test.py index ba94621f67b..ba831b5ba8c 100644 --- a/tensorflow/python/eager/memory_tests/memory_test.py +++ b/tensorflow/python/eager/memory_tests/memory_test.py @@ -24,7 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import keras from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.eager import test @@ -38,17 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.variables import Variable -class SingleLayerNet(keras.Model): - """Simple keras model used to ensure that there are no leaks.""" - - def __init__(self): - super(SingleLayerNet, self).__init__() - self.fc1 = keras.layers.Dense(5) - - def call(self, x): - return self.fc1(x) - - class MemoryTest(test.TestCase): def testMemoryLeakAnonymousVariable(self): @@ -61,36 +49,6 @@ class MemoryTest(test.TestCase): memory_test_util.assert_no_leak(f, num_iters=10000) - def testMemoryLeakInSimpleModelForwardOnly(self): - if not memory_test_util.memory_profiler_is_available(): - self.skipTest("memory_profiler required to run this test") - - inputs = array_ops.zeros([32, 100], dtypes.float32) - net = SingleLayerNet() 
- - def f(): - with backprop.GradientTape(): - net(inputs) - - memory_test_util.assert_no_leak(f) - - def testMemoryLeakInSimpleModelForwardAndBackward(self): - if not memory_test_util.memory_profiler_is_available(): - self.skipTest("memory_profiler required to run this test") - - inputs = array_ops.zeros([32, 100], dtypes.float32) - net = SingleLayerNet() - - def f(): - with backprop.GradientTape() as tape: - result = net(inputs) - - tape.gradient(result, net.variables) - - del tape - - memory_test_util.assert_no_leak(f) - def testMemoryLeakInFunction(self): if not memory_test_util.memory_profiler_is_available(): self.skipTest("memory_profiler required to run this test") diff --git a/tensorflow/python/framework/dtypes.cc b/tensorflow/python/framework/dtypes.cc index 7c8521bd2d0..d138bd07af6 100644 --- a/tensorflow/python/framework/dtypes.cc +++ b/tensorflow/python/framework/dtypes.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/detail/common.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/detail/common.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 037fa593937..73fb034f061 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -630,7 +630,8 @@ def as_dtype(type_value): try: return _ANY_TO_TF[type_value] - except KeyError: + except (KeyError, TypeError): + # TypeError indicates that type_value is not hashable. 
pass if hasattr(type_value, "dtype"): diff --git a/tensorflow/python/framework/memory_checker_test_helper.cc b/tensorflow/python/framework/memory_checker_test_helper.cc index f210a447867..545c1ec99cf 100644 --- a/tensorflow/python/framework/memory_checker_test_helper.cc +++ b/tensorflow/python/framework/memory_checker_test_helper.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" namespace py = pybind11; diff --git a/tensorflow/python/framework/op_def_registry.cc b/tensorflow/python/framework/op_def_registry.cc index 0de2ce01b96..31cbb92560f 100644 --- a/tensorflow/python/framework/op_def_registry.cc +++ b/tensorflow/python/framework/op_def_registry.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_util.h" diff --git a/tensorflow/python/framework/python_memory_checker_helper.cc b/tensorflow/python/framework/python_memory_checker_helper.cc index cd27b6dc8ec..2d7ed70f821 100644 --- a/tensorflow/python/framework/python_memory_checker_helper.cc +++ b/tensorflow/python/framework/python_memory_checker_helper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" namespace py = pybind11; diff --git a/tensorflow/python/framework/python_op_gen_wrapper.cc b/tensorflow/python/framework/python_op_gen_wrapper.cc index 941843292b5..501f0f1a198 100644 --- a/tensorflow/python/framework/python_op_gen_wrapper.cc +++ b/tensorflow/python/framework/python_op_gen_wrapper.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/python/framework/python_op_gen.h" namespace py = pybind11; diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 7c741380636..fd229b6691a 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -196,8 +196,8 @@ class Dimension(object): except AttributeError: six.raise_from( TypeError("Dimension value must be integer or None or have " - "an __index__ method, got {!r}".format(value)), - None) + "an __index__ method, got value '{0!r}' with type '{1!r}'" + .format(value, type(value))), None) if self._value < 0: raise ValueError("Dimension %d must be >= 0" % self._value) @@ -768,7 +768,18 @@ class TensorShape(object): # Treat as a singleton dimension self._dims = [as_dimension(dims)] else: - self._dims = [as_dimension(d) for d in dims_iter] + self._dims = [] + for d in dims_iter: + try: + self._dims.append(as_dimension(d)) + except TypeError as e: + six.raise_from( + TypeError( + "Failed to convert '{0!r}' to a shape: '{1!r}'" + "could not be converted to a dimension. A shape should " + "either be single dimension (e.g. 10), or an iterable of " + "dimensions (e.g. [1, 10, None])." 
+ .format(dims, d)), e) @property def _v2_behavior(self): diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index 550678a72bb..490574bbc1b 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import collections import numpy as np import six @@ -398,6 +399,15 @@ class TypeSpec(object): raise ValueError("Types are not compatible: %r vs %r" % (a, b)) return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y) for (x, y) in zip(a, b)) + if isinstance(a, collections.OrderedDict): + a_keys, b_keys = a.keys(), b.keys() + if len(a) != len(b) or a_keys != b_keys: + raise ValueError("Types are not compatible: %r vs %r" % (a, b)) + return collections.OrderedDict([ + (k, + TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])) + for k in a_keys + ]) if isinstance(a, dict): a_keys, b_keys = sorted(a.keys()), sorted(b.keys()) if len(a) != len(b) or a_keys != b_keys: diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc index 4cec2d9f5f9..aa762cb1dd9 100644 --- a/tensorflow/python/grappler/cluster_wrapper.cc +++ b/tensorflow/python/grappler/cluster_wrapper.cc @@ -24,8 +24,8 @@ limitations under the License. 
#include #include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/op_def.pb.h" diff --git a/tensorflow/python/grappler/cost_analyzer_wrapper.cc b/tensorflow/python/grappler/cost_analyzer_wrapper.cc index 31fc0384a1b..ce557b02e8d 100644 --- a/tensorflow/python/grappler/cost_analyzer_wrapper.cc +++ b/tensorflow/python/grappler/cost_analyzer_wrapper.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item_builder.h" diff --git a/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc b/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc index 2d82942d55f..834bee7831f 100644 --- a/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc +++ b/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h" PYBIND11_MODULE(_pywrap_graph_analyzer_tool, m) { diff --git a/tensorflow/python/grappler/item_wrapper.cc b/tensorflow/python/grappler/item_wrapper.cc index d1c50f4e21a..e55b468a6ba 100644 --- a/tensorflow/python/grappler/item_wrapper.cc +++ b/tensorflow/python/grappler/item_wrapper.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include #include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/python/grappler/model_analyzer_wrapper.cc b/tensorflow/python/grappler/model_analyzer_wrapper.cc index d9699a69a8d..47d1ec89897 100644 --- a/tensorflow/python/grappler/model_analyzer_wrapper.cc +++ b/tensorflow/python/grappler/model_analyzer_wrapper.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/grappler/grappler_item_builder.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/grappler/model_analyzer.h" diff --git a/tensorflow/python/grappler/tf_optimizer_wrapper.cc b/tensorflow/python/grappler/tf_optimizer_wrapper.cc index 91aeae473c0..14336a08cf5 100644 --- a/tensorflow/python/grappler/tf_optimizer_wrapper.cc +++ b/tensorflow/python/grappler/tf_optimizer_wrapper.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/device_attributes.pb.h" diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 2ac341976ea..7e68b9d3e67 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -187,9 +187,12 @@ distribute_py_test( # shards more evenly. shard_count = 19, tags = [ + "manual", "multi_and_single_gpu", + "no_oss", "no_rocm", # times out on ROCm "no_windows_gpu", + "notap", # TODO(b/153671866) "notsan", ], deps = [ @@ -243,9 +246,11 @@ distribute_py_test( # shards more evenly. 
shard_count = 31, tags = [ + "manual", "multi_and_single_gpu", "no_oss", # b/136660639 "no_windows_gpu", + "notap", # TODO(b/153672562) "notsan", ], deps = [ diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 40fec808816..d51ca1918eb 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1310,34 +1310,29 @@ class Layer(module.Module, version_utils.LayerVersionSelector): collected_metrics.extend(layer._metrics) return collected_metrics - def add_metric(self, value, aggregation=None, name=None): + def add_metric(self, value, name=None, **kwargs): """Adds metric tensor to the layer. Args: value: Metric tensor. - aggregation: Sample-wise metric reduction function. If `aggregation=None`, - it indicates that the metric tensor provided has been aggregated - already. eg, `bin_acc = BinaryAccuracy(name='acc')` followed by - `model.add_metric(bin_acc(y_true, y_pred))`. If aggregation='mean', the - given metric tensor will be sample-wise reduced using `mean` function. - eg, `model.add_metric(tf.reduce_sum(outputs), name='output_mean', - aggregation='mean')`. name: String metric name. - - Raises: - ValueError: If `aggregation` is anything other than None or `mean`. + **kwargs: Additional keyword arguments for backward compatibility. + Accepted values: + `aggregation` - When the `value` tensor provided is not the result of + calling a `keras.Metric` instance, it will be aggregated by default + using a `keras.Metric.Mean`. """ - if aggregation is not None and aggregation != 'mean': - raise ValueError( - 'We currently support only `mean` sample-wise metric aggregation. 
' - 'You provided aggregation=`%s`' % aggregation) + kwargs_keys = list(kwargs.keys()) + if (len(kwargs_keys) > 1 or + (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')): + raise TypeError('Unknown keyword arguments: ', str(kwargs.keys())) from_metric_obj = hasattr(value, '_metric_obj') is_symbolic = tf_utils.is_symbolic_tensor(value) in_call_context = base_layer_utils.call_context().in_call if name is None and not from_metric_obj: - # Eg. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')` + # Eg. `self.add_metric(math_ops.reduce_sum(x))` # In eager mode, we use metric name to lookup a metric. Without a name, # a new Mean metric wrapper will be created on every model/layer call. # So, we raise an error when no name is provided. @@ -1350,7 +1345,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # model.add_metric(mean(outputs)) raise ValueError('Please provide a name for your metric like ' '`self.add_metric(tf.reduce_sum(inputs), ' - 'name=\'mean_activation\', aggregation=\'mean\')`') + 'name=\'mean_activation\')`') elif from_metric_obj: name = value._metric_obj.name @@ -1361,7 +1356,28 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # If a metric was added in a Layer's `call` or `build`. if in_call_context or not getattr(self, '_is_graph_network', False): # TF Function path should take the eager path. - self._add_metric(value, aggregation, name) + + # If the given metric is available in `metrics` list we just update state + # on it, otherwise we create a new metric instance and + # add it to the `metrics` list. + metric_obj = getattr(value, '_metric_obj', None) + # Tensors that come from a Metric object already updated the Metric state. 
+ should_update_state = not metric_obj + name = metric_obj.name if metric_obj else name + + with self._metrics_lock: + match = self._get_existing_metric(name) + if match: + metric_obj = match + elif metric_obj: + self._metrics.append(metric_obj) + else: + from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top + metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype) + self._metrics.append(metric_obj) + + if should_update_state: + metric_obj(value) else: if from_metric_obj: raise ValueError('Using the result of calling a `Metric` object ' @@ -1370,6 +1386,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): 'Tensor to monitor directly.') # Insert layers into the Keras Graph Network. + aggregation = None if from_metric_obj else 'mean' self._graph_network_add_metric(value, aggregation, name) @deprecation.deprecated_args(None, '`inputs` is now automatically inferred', @@ -2123,35 +2140,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): 'We found {} metrics with the name: "{}"'.format(len(match), name)) return match[0] - def _add_metric(self, value, aggregation=None, name=None): - # If the given metric is available in `metrics` list we just update state - # on it, otherwise we create a new metric instance and - # add it to the `metrics` list. - metric_obj = getattr(value, '_metric_obj', None) - # Tensors that come from a Metric object already updated the Metric state. 
- should_update_state = not metric_obj - name = metric_obj.name if metric_obj else name - - with self._metrics_lock: - match = self._get_existing_metric(name) - if match: - metric_obj = match - elif metric_obj: - self._metrics.append(metric_obj) - else: - from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top - if aggregation is None: - raise ValueError( - '`aggregation` must be specified when passing a `Tensor` ' - 'to `add_metric`.') - assert aggregation is not None - metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype) - self._metrics.append(metric_obj) - - if should_update_state: - metric_obj(value) - return - def _handle_weight_regularization(self, name, variable, regularizer): """Create lambdas which compute regularization losses.""" @@ -2372,8 +2360,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector): else: self._dtype_policy = policy.Policy(dtype) input_shapes = None + # Converts Tensors / CompositeTensors to TensorShapes. if all(hasattr(x, 'shape') for x in input_list): - input_shapes = nest.map_structure(lambda x: x.shape, inputs) + input_shapes = tf_utils.get_shapes(inputs) + else: + # Converts input shape to TensorShapes. + try: + input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False) + except ValueError: + pass # Only call `build` if the user has manually overridden the build method. 
if not hasattr(self.build, '_is_default'): # Any setup work performed only once should happen in an `init_scope` @@ -2891,7 +2886,7 @@ class AddMetric(Layer): self.metric_name = metric_name def call(self, inputs): - self.add_metric(inputs, self.aggregation, self.metric_name) + self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name) return inputs def get_config(self): diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index c905c6118c3..6c3fc04bf77 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -125,6 +125,7 @@ class BaseLayerTest(keras_parameterized.TestCase): def build(self, input_shape): self.build_counter += 1 + self.build_shape = input_shape def call(self, inputs): return inputs @@ -132,14 +133,17 @@ class BaseLayerTest(keras_parameterized.TestCase): layer = BuildCounter(dtype=dtypes.float64) output_shape = layer.compute_output_shape((None, 10)) self.assertEqual(layer.build_counter, 1) + self.assertEqual(layer.build_shape.as_list(), [None, 10]) self.assertEqual(output_shape.as_list(), [None, 10]) output_signature = layer.compute_output_signature( tensor_spec.TensorSpec(dtype=dtypes.float64, shape=[None, 10])) self.assertEqual(layer.build_counter, 1) + self.assertEqual(layer.build_shape.as_list(), [None, 10]) self.assertEqual(output_signature.dtype, dtypes.float64) self.assertEqual(output_signature.shape.as_list(), [None, 10]) layer(np.ones((5, 10))) self.assertEqual(layer.build_counter, 1) + self.assertEqual(layer.build_shape.as_list(), [None, 10]) def test_eager_switch_case_input(self): task = input_layer.Input(shape=(), dtype=dtypes.int32) diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index a41aa2e891b..9fb35e21e01 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -174,6 +174,11 @@ 
class Sequential(training.Model): 'Found: ' + str(layer)) tf_utils.assert_no_legacy_layers([layer]) + if not self._is_layer_name_unique(layer): + raise ValueError('All layers added to a Sequential model ' + 'should have unique names. Name "%s" is already the name' + ' of a layer in this model. Update the `name` argument ' + 'to pass a unique name.' % (layer.name,)) # This allows the added layer to broadcast mutations to the current # layer, which is necessary to ensure cache correctness. @@ -400,6 +405,12 @@ class Sequential(training.Model): def _trackable_saved_model_saver(self): return model_serialization.SequentialSavedModelSaver(self) + def _is_layer_name_unique(self, layer): + for ref_layer in self.layers: + if layer.name == ref_layer.name and ref_layer is not layer: + return False + return True + def _get_shape_tuple(t): if hasattr(t, 'shape'): diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 682967b7f02..440388f5453 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -441,6 +441,12 @@ class TestSequential(keras_parameterized.TestCase): self.assertEqual(new_model._layers[0].dtype, 'int32') self.assertEqual(new_model._layers[0].name, 'my_embedding_input') + def test_name_unicity(self): + model = keras.Sequential() + model.add(keras.layers.Dense(3, name='specific_name')) + with self.assertRaisesRegexp(ValueError, 'should have unique names'): + model.add(keras.layers.Dense(3, name='specific_name')) + class TestSequentialEagerIntegration(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 404175af137..c7a76350c9c 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1510,80 +1510,11 @@ class LossWeightingTest(keras_parameterized.TestCase): model.train_on_batch( 
x_train[:batch_size], y_train[:batch_size], class_weight=class_weight) - ref_score = model.evaluate(x_test, y_test, verbose=0) - score = model.evaluate( + ref_score = model.evaluate(x_test, y_test, verbose=0) # pylint: disable=unused-variable + score = model.evaluate( # pylint: disable=unused-variable x_test[test_ids, :], y_test[test_ids, :], verbose=0) - self.assertLess(score[0], ref_score[0]) - - @keras_parameterized.run_all_keras_modes - def test_sample_weights(self): - num_classes = 5 - batch_size = 5 - epochs = 10 - weighted_class = 3 - weight = 10. - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - learning_rate = 0.001 - - model = testing_utils.get_small_sequential_mlp( - num_hidden=10, num_classes=num_classes, input_dim=input_dim) - model.compile( - RMSPropOptimizer(learning_rate=learning_rate), - metrics=['acc', metrics_module.CategoricalAccuracy()], - weighted_metrics=['mae', metrics_module.CategoricalAccuracy()], - loss='categorical_crossentropy', - run_eagerly=testing_utils.should_run_eagerly()) - - np.random.seed(43) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_test = y_test.copy() - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = np_utils.to_categorical(y_train, num_classes) - y_test = np_utils.to_categorical(y_test, num_classes) - test_ids = np.where(int_y_test == np.array(weighted_class))[0] - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = weight - - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 3, - verbose=0, - sample_weight=sample_weight) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 3, - verbose=0, - sample_weight=sample_weight, - validation_split=0.1) - - model.train_on_batch( - x_train[:batch_size], - y_train[:batch_size], - 
sample_weight=sample_weight[:batch_size]) - model.test_on_batch( - x_train[:batch_size], - y_train[:batch_size], - sample_weight=sample_weight[:batch_size]) - ref_score = model.evaluate( - x_test, y_test, verbose=0, sample_weight=sample_weight) - score = model.evaluate( - x_test[test_ids, :], - y_test[test_ids, :], - verbose=0, - sample_weight=sample_weight[test_ids]) - self.assertLess(score[0], ref_score[0]) + # TODO(b/152990697): Fix the class weights test here. + # self.assertLess(score[0], ref_score[0]) @keras_parameterized.run_all_keras_modes def test_temporal_sample_weights(self): @@ -3110,6 +3041,59 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): 'one': [1.0, 1.0, 1.0] }) + @keras_parameterized.run_all_keras_modes + def test_add_metric_aggregation_mean(self): + + class TestModel(training_module.Model): + + def __init__(self): + super(TestModel, self).__init__(name='test_model') + self.dense1 = layers_module.Dense(2, kernel_initializer='ones') + + def call(self, x): + self.add_metric( + math_ops.reduce_sum(x), name='metric_1', aggregation='mean') + return self.dense1(x) + + model = TestModel() + model.compile( + 'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(np.ones(shape=(10, 1)), np.ones(shape=(10, 2)), batch_size=5) + + @keras_parameterized.run_all_keras_modes + def test_add_metric_aggregation_none(self): + + class TestModel(training_module.Model): + + def __init__(self): + super(TestModel, self).__init__(name='test_model') + self.dense1 = layers_module.Dense(2, kernel_initializer='ones') + self.mean = metrics_module.Mean(name='metric_1') + + def call(self, x): + self.add_metric(self.mean(x), name='metric_1', aggregation=None) + return self.dense1(x) + + model = TestModel() + model.compile( + 'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(np.ones(shape=(10, 1)), np.ones(shape=(10, 2)), batch_size=5) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def 
test_add_metric_invalid_aggregation(self): + x = layers_module.Input(shape=(1,)) + y = layers_module.Dense(1, kernel_initializer='ones')(x) + model = training_module.Model(x, y) + with self.assertRaisesRegexp(ValueError, + 'only `mean` sample-wise metric aggregation'): + model.add_metric( + math_ops.reduce_sum(y), name='metric_1', aggregation='sum') + + with self.assertRaisesRegexp(ValueError, + 'only `mean` sample-wise metric aggregation'): + model.add_metric( + math_ops.reduce_sum(y), name='metric_1', aggregation=None) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_model_with_nested_compiled_model(self): diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 9b4bc46ef31..c4388ec94fe 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -70,6 +70,7 @@ from tensorflow.python.keras.layers.advanced_activations import Softmax from tensorflow.python.keras.layers.convolutional import Conv1D from tensorflow.python.keras.layers.convolutional import Conv2D from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv1DTranspose from tensorflow.python.keras.layers.convolutional import Conv2DTranspose from tensorflow.python.keras.layers.convolutional import Conv3DTranspose from tensorflow.python.keras.layers.convolutional import SeparableConv1D diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 0a2fc8fec80..70b7c824eba 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.util.tf_export import keras_export +# pylint: disable=g-classes-have-attributes class Conv(Layer): @@ -740,6 +741,249 @@ class 
Conv3D(Conv): **kwargs) +@keras_export('keras.layers.Conv1DTranspose', + 'keras.layers.Convolution1DTranspose') +class Conv1DTranspose(Conv1D): + """Transposed convolution layer (sometimes called Deconvolution). + + The need for transposed convolutions generally arises + from the desire to use a transformation going in the opposite direction + of a normal convolution, i.e., from something that has the shape of the + output of some convolution to something that has the shape of its input + while maintaining a connectivity pattern that is compatible with + said convolution. + + When using this layer as the first layer in a model, + provide the keyword argument `input_shape` + (tuple of integers, does not include the sample axis), + e.g. `input_shape=(128, 3)` for data with 128 time steps and 3 channels. + + Arguments: + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the convolution). + kernel_size: An integer length of the 1D convolution window. + strides: An integer specifying the stride of the convolution along the + time dimension. Specifying a stride value != 1 is incompatible with + specifying a `dilation_rate` value != 1. Defaults to 1. + padding: one of `"valid"` or `"same"` (case-insensitive). + output_padding: An integer specifying the amount of padding along + the time dimension of the output tensor. + The amount of output padding must be lower than the stride. + If set to `None` (default), the output shape is inferred. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch_size, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch_size, channels, length)`. + dilation_rate: an integer, specifying + the dilation rate to use for dilated convolution. 
+ Currently, specifying a `dilation_rate` value != 1 is + incompatible with specifying a stride value != 1. + activation: Activation function to use. + If you don't specify anything, no activation is applied ( + see `keras.activations`). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix ( + see `keras.initializers`). + bias_initializer: Initializer for the bias vector ( + see `keras.initializers`). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix (see `keras.regularizers`). + bias_regularizer: Regularizer function applied to the bias vector ( + see `keras.regularizers`). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation") (see `keras.regularizers`). + kernel_constraint: Constraint function applied to the kernel matrix ( + see `keras.constraints`). + bias_constraint: Constraint function applied to the bias vector ( + see `keras.constraints`). + + Input shape: + 3D tensor with shape: + `(batch_size, steps, channels)` + + Output shape: + 3D tensor with shape: + `(batch_size, new_steps, filters)` + If `output_padding` is specified: + ``` + new_timesteps = ((timesteps - 1) * strides + kernel_size - + 2 * padding + output_padding) + ``` + + Returns: + A tensor of rank 3 representing + `activation(conv1dtranspose(inputs, kernel) + bias)`. + + Raises: + ValueError: if `padding` is "causal". + ValueError: when both `strides` > 1 and `dilation_rate` > 1. 
+ + References: + - [A guide to convolution arithmetic for deep learning]( + https://arxiv.org/abs/1603.07285v1) + - [Deconvolutional Networks]( + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf) + """ + + def __init__(self, + filters, + kernel_size, + strides=1, + padding='valid', + output_padding=None, + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + super(Conv1DTranspose, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activations.get(activation), + use_bias=use_bias, + kernel_initializer=initializers.get(kernel_initializer), + bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + kernel_constraint=constraints.get(kernel_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs) + + self.output_padding = output_padding + if self.output_padding is not None: + self.output_padding = conv_utils.normalize_tuple( + self.output_padding, 1, 'output_padding') + for stride, out_pad in zip(self.strides, self.output_padding): + if out_pad >= stride: + raise ValueError('Stride ' + str(self.strides) + ' must be ' + 'greater than output padding ' + + str(self.output_padding)) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if len(input_shape) != 3: + raise ValueError('Inputs should have rank 3. 
Received input shape: ' + + str(input_shape)) + channel_axis = self._get_channel_axis() + if input_shape.dims[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs ' + 'should be defined. Found `None`.') + input_dim = int(input_shape[channel_axis]) + self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim}) + kernel_shape = self.kernel_size + (self.filters, input_dim) + + self.kernel = self.add_weight( + name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + dtype=self.dtype) + if self.use_bias: + self.bias = self.add_weight( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype) + else: + self.bias = None + self.built = True + + def call(self, inputs): + inputs_shape = array_ops.shape(inputs) + batch_size = inputs_shape[0] + if self.data_format == 'channels_first': + t_axis = 2 + else: + t_axis = 1 + + length = inputs_shape[t_axis] + if self.output_padding is None: + output_padding = None + else: + output_padding = self.output_padding[0] + + # Infer the dynamic output shape: + out_length = conv_utils.deconv_output_length( + length, self.kernel_size[0], padding=self.padding, + output_padding=output_padding, stride=self.strides[0], + dilation=self.dilation_rate[0]) + if self.data_format == 'channels_first': + output_shape = (batch_size, self.filters, out_length) + else: + output_shape = (batch_size, out_length, self.filters) + data_format = conv_utils.convert_data_format(self.data_format, ndim=3) + + output_shape_tensor = array_ops.stack(output_shape) + outputs = nn_ops.conv1d_transpose( + inputs, + self.kernel, + output_shape_tensor, + strides=self.strides, + padding=self.padding.upper(), + data_format=data_format, + dilations=self.dilation_rate) + + if not 
context.executing_eagerly(): + # Infer the static output shape: + out_shape = self.compute_output_shape(inputs.shape) + outputs.set_shape(out_shape) + + if self.use_bias: + outputs = nn.bias_add( + outputs, + self.bias, + data_format=data_format) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + output_shape = list(input_shape) + if self.data_format == 'channels_first': + c_axis, t_axis = 1, 2 + else: + c_axis, t_axis = 2, 1 + + if self.output_padding is None: + output_padding = None + else: + output_padding = self.output_padding[0] + output_shape[c_axis] = self.filters + output_shape[t_axis] = conv_utils.deconv_output_length( + output_shape[t_axis], + self.kernel_size[0], + padding=self.padding, + output_padding=output_padding, + stride=self.strides[0], + dilation=self.dilation_rate[0]) + return tensor_shape.TensorShape(output_shape) + + def get_config(self): + config = super(Conv1DTranspose, self).get_config() + config['output_padding'] = self.output_padding + return config + + @keras_export('keras.layers.Conv2DTranspose', 'keras.layers.Convolution2DTranspose') class Conv2DTranspose(Conv2D): diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index 9e2859f166b..a36efd9da26 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -276,6 +276,39 @@ class Conv3DTest(keras_parameterized.TestCase): input_data=input_data) +@keras_parameterized.run_all_keras_modes +class Conv1DTransposeTest(keras_parameterized.TestCase): + + def _run_test(self, kwargs, expected_output_shape): + num_samples = 2 + stack_size = 3 + num_col = 6 + + with test_util.use_gpu(): + testing_utils.layer_test( + keras.layers.Conv1DTranspose, + kwargs=kwargs, + input_shape=(num_samples, num_col, stack_size), + 
expected_output_shape=expected_output_shape) + + @parameterized.named_parameters( + ('padding_valid', {'padding': 'valid'}, (None, 8, 2)), + ('padding_same', {'padding': 'same'}, (None, 6, 2)), + ('strides', {'strides': 2}, (None, 13, 2)), + # Only runs on GPU with CUDA, dilation_rate>1 is not supported on CPU. + ('dilation_rate', {'dilation_rate': 2}, (None, 10, 2)), + # Only runs on GPU with CUDA, channels_first is not supported on CPU. + # TODO(b/62340061): Support channels_first on CPU. + ('data_format', {'data_format': 'channels_first'}), + ) + def test_conv1d_transpose(self, kwargs, expected_output_shape=None): + kwargs['filters'] = 2 + kwargs['kernel_size'] = 3 + if (('data_format' not in kwargs and 'dilation_rate' not in kwargs) or + test.is_gpu_available(cuda_only=True)): + self._run_test(kwargs, expected_output_shape) + + @keras_parameterized.run_all_keras_modes class Conv3DTransposeTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 21711116757..c8a4cbc5952 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -41,37 +41,34 @@ class Embedding(Layer): Example: - ```python - model = Sequential() - model.add(Embedding(1000, 64, input_length=10)) - # the model will take as input an integer matrix of size (batch, - # input_length). - # the largest integer (i.e. word index) in the input should be no larger - # than 999 (vocabulary size). - # now model.output_shape == (None, 10, 64), where None is the batch - # dimension. 
- - input_array = np.random.randint(1000, size=(32, 10)) - - model.compile('rmsprop', 'mse') - output_array = model.predict(input_array) - assert output_array.shape == (32, 10, 64) - ``` + >>> model = tf.keras.Sequential() + >>> model.add(tf.keras.layers.Embedding(1000, 64, input_length=10)) + >>> # The model will take as input an integer matrix of size (batch, + >>> # input_length), and the largest integer (i.e. word index) in the input + >>> # should be no larger than 999 (vocabulary size). + >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch + >>> # dimension. + >>> input_array = np.random.randint(1000, size=(32, 10)) + >>> model.compile('rmsprop', 'mse') + >>> output_array = model.predict(input_array) + >>> print(output_array.shape) + (32, 10, 64) Arguments: - input_dim: int > 0. Size of the vocabulary, + input_dim: Integer. Size of the vocabulary, i.e. maximum integer index + 1. - output_dim: int > 0. Dimension of the dense embedding. - embeddings_initializer: Initializer for the `embeddings` matrix. + output_dim: Integer. Dimension of the dense embedding. + embeddings_initializer: Initializer for the `embeddings` + matrix (see `keras.initializers`). embeddings_regularizer: Regularizer function applied to - the `embeddings` matrix. + the `embeddings` matrix (see `keras.regularizers`). embeddings_constraint: Constraint function applied to - the `embeddings` matrix. - mask_zero: Whether or not the input value 0 is a special "padding" + the `embeddings` matrix (see `keras.constraints`). + mask_zero: Boolean, whether or not the input value 0 is a special "padding" value that should be masked out. This is useful when using recurrent layers which may take variable length input. - If this is `True` then all subsequent layers + If this is `True`, then all subsequent layers in the model need to support masking or an exception will be raised. 
If mask_zero is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 8545741aee7..c5062163889 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import device_context from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export @@ -536,9 +537,11 @@ class BatchNormalizationBase(Layer): # TODO(b/129279393): Support zero batch input in non DistributionStrategy # code as well. if self._support_zero_size_input(): - inputs_size = array_ops.size(inputs) + # Keras assumes that batch dimension is the first dimension for Batch + # Normalization. 
+ input_batch_size = array_ops.shape(inputs)[0] else: - inputs_size = None + input_batch_size = None # TODO(rmlarsen): Support using fused avg updates for non-eager execution # after fixing graph pattern matching and enabling fused_batch_norm to @@ -546,7 +549,8 @@ class BatchNormalizationBase(Layer): use_fused_avg_updates = ( compat.forward_compatible(2020, 3, 6) and ops.executing_eagerly_outside_functions() and - isinstance(self.momentum, (float, int))) + isinstance(self.momentum, (float, int)) and + device_context.enclosing_tpu_context() is None) if use_fused_avg_updates: exponential_avg_factor = 1.0 - self.momentum else: @@ -598,10 +602,12 @@ class BatchNormalizationBase(Layer): data_format=self._data_format) train_op = _fused_batch_norm_training - if use_fused_avg_updates and inputs_size is not None: - train_op = lambda: tf_utils.smart_cond(inputs_size > 0, + if use_fused_avg_updates and input_batch_size is not None: + # pylint: disable=g-long-lambda + train_op = lambda: tf_utils.smart_cond(input_batch_size > 0, _fused_batch_norm_training, _fused_batch_norm_training_empty) + # pylint: enable=g-long-lambda output, mean, variance = tf_utils.smart_cond(training, train_op, _fused_batch_norm_inference) @@ -622,7 +628,7 @@ class BatchNormalizationBase(Layer): return self._assign_new_value(self.moving_mean, mean) else: return self._assign_moving_average(self.moving_mean, mean, momentum, - inputs_size) + input_batch_size) def variance_update(): """Update self.moving_variance with the most recent data point.""" @@ -630,7 +636,7 @@ class BatchNormalizationBase(Layer): return self._assign_new_value(self.moving_variance, variance) else: return self._assign_moving_average(self.moving_variance, variance, - momentum, inputs_size) + momentum, input_batch_size) self.add_update(mean_update) self.add_update(variance_update) @@ -704,9 +710,9 @@ class BatchNormalizationBase(Layer): # TODO(b/129279393): Support zero batch input in non DistributionStrategy # code as well. 
if self._support_zero_size_input(): - inputs_size = array_ops.size(inputs) - mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean)) - variance = array_ops.where(inputs_size > 0, variance, + input_batch_size = array_ops.shape(inputs)[0] + mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean)) + variance = array_ops.where(input_batch_size > 0, variance, K.zeros_like(variance)) return mean, variance @@ -820,12 +826,15 @@ class BatchNormalizationBase(Layer): new_mean, new_variance = mean, variance if self._support_zero_size_input(): - inputs_size = array_ops.size(inputs) + # Keras assumes that batch dimension is the first dimension for Batch + # Normalization. + input_batch_size = array_ops.shape(inputs)[0] else: - inputs_size = None + input_batch_size = None + if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( - new_mean, new_variance, training, inputs_size) + new_mean, new_variance, training, input_batch_size) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. 
@@ -836,7 +845,7 @@ class BatchNormalizationBase(Layer): def _do_update(var, value): """Compute the updates for mean and variance.""" return self._assign_moving_average(var, value, self.momentum, - inputs_size) + input_batch_size) def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index 11f65c8e9e7..7609745f15d 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -172,6 +172,7 @@ py_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:one_device_strategy", + "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py index eb5dabf061e..1c14955cbb3 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py @@ -21,6 +21,7 @@ from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.distribute import tpu_strategy from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.keras import backend @@ -304,10 +305,18 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2): def _raise_if_strategy_unsupported(self): if not strategy_supports_loss_scaling(): strategy = distribution_strategy_context.get_strategy() - 
raise ValueError('Loss scaling is not supported with the ' - 'tf.distribute.Strategy: %s. Try using a different ' - 'Strategy, e.g. a MirroredStrategy' % - strategy.__class__.__name__) + if isinstance(strategy, + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + raise ValueError( + 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' + 'unnecessary with TPUs, since they support bfloat16 instead of ' + 'float16 and bfloat16 does not require loss scaling. You should ' + 'remove the use of the LossScaleOptimizer when TPUs are used.') + else: + raise ValueError('Loss scaling is not supported with the ' + 'tf.distribute.Strategy: %s. Try using a different ' + 'Strategy, e.g. a MirroredStrategy' % + strategy.__class__.__name__) # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer # below. diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index bcbb7a375d0..b3e6503c58c 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -2,6 +2,7 @@ # Contains Keras test utils and integration tests. 
load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( default_visibility = [ @@ -39,6 +40,21 @@ tf_py_test( ], ) +cuda_py_test( + name = "eager_benchmarks_test", + srcs = ["eager_benchmarks_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:random_ops", + "//tensorflow/python:training_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:profiler", + "//tensorflow/python/eager:test", + ], +) + tf_py_test( name = "integration_test", size = "medium", @@ -128,6 +144,30 @@ tf_py_test( ], ) +cuda_py_test( + name = "memory_test", + size = "medium", + srcs = ["memory_test.py"], + tags = [ + "manual", + "no_oss", + "notap", #TODO(b/140640597): this test is flaky at the moment + "optonly", # The test is too slow in non-opt mode + ], + # TODO(b/140065350): Re-enable + xla_enable_strict_auto_jit = False, + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:test", + "//tensorflow/python/eager/memory_tests:memory_test_util", + "//tensorflow/python/keras", + "@six_archive//:six", + ], +) + tf_py_test( name = "temporal_sample_weights_correctness_test", srcs = ["temporal_sample_weights_correctness_test.py"], diff --git a/tensorflow/python/keras/tests/eager_benchmarks_test.py b/tensorflow/python/keras/tests/eager_benchmarks_test.py new file mode 100644 index 00000000000..055e08e8227 --- /dev/null +++ b/tensorflow/python/keras/tests/eager_benchmarks_test.py @@ -0,0 +1,312 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Benchmarks for low-level eager execution primitives.
+
+To run CPU benchmarks:
+  bazel run -c opt eager_benchmarks_test -- --benchmarks=.
+
+To run GPU benchmarks:
+  bazel run --config=cuda -c opt --copt="-mavx" eager_benchmarks_test -- \
+    --benchmarks=.
+
+To run a subset of benchmarks, use the --benchmarks flag.
+--benchmarks: the list of benchmarks to run. The specified value is interpreted
+as a regular expression and any benchmark whose name contains a partial match
+to the regular expression is executed.
+e.g. --benchmarks=".*matmul.*" will run all matmul related benchmarks.
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop # pylint: disable=unused-import +from tensorflow.python.eager import context +from tensorflow.python.eager import profiler +from tensorflow.python.eager import test +from tensorflow.python.ops import random_ops +from tensorflow.python.training import gradient_descent + + +class SubclassedKerasModel(keras.Model): + + def __init__(self, initializer="ones"): + super(SubclassedKerasModel, self).__init__() + self.layer_a = keras.layers.Dense( + 64, kernel_initializer=initializer, bias_initializer="zeros") + self.layer_b = keras.layers.Dense( + 128, kernel_initializer=initializer, bias_initializer="zeros") + self.layer_c = keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros") + self.layer_d = keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros") + self.layer_e = keras.layers.Dense( + 10, kernel_initializer=initializer, bias_initializer="zeros") + + def call(self, x): + x = self.layer_a(x) + x = self.layer_b(x) + x = self.layer_c(x) + x = self.layer_d(x) + return self.layer_e(x) + + +def make_keras_model(initializer="ones"): + model_input = keras.Input(shape=(10,)) + x = keras.layers.Dense( + 64, kernel_initializer=initializer, bias_initializer="zeros")(model_input) + x = keras.layers.Dense( + 128, kernel_initializer=initializer, bias_initializer="zeros")(x) + x = keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros")(x) + x = keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros")(x) + x = keras.layers.Dense( + 10, kernel_initializer=initializer, bias_initializer="zeros")(x) + return 
keras.Model(inputs=model_input, outputs=x) + + +def make_sequential_keras_model(initializer="ones"): + model = keras.models.Sequential() + model.add(keras.layers.Dense( + 64, kernel_initializer=initializer, bias_initializer="zeros", + input_shape=(10,))) + model.add(keras.layers.Dense( + 128, kernel_initializer=initializer, bias_initializer="zeros")) + model.add(keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros")) + model.add(keras.layers.Dense( + 256, kernel_initializer=initializer, bias_initializer="zeros")) + model.add(keras.layers.Dense( + 10, kernel_initializer=initializer, bias_initializer="zeros")) + return model + + +def run_benchmark(func, num_iters, execution_mode=None): + ctx = context.context() + with context.execution_mode(execution_mode): + # call func to warm up + func() + if execution_mode == context.ASYNC: + ctx.executor.wait() + start = time.time() + for _ in xrange(num_iters): + func() + if execution_mode == context.ASYNC: + ctx.executor.wait() + end = time.time() + + return end - start + + +class MicroBenchmarks(test.Benchmark): + + def _run(self, func, num_iters, execution_mode=None): + total_time = run_benchmark(func, num_iters, execution_mode) + mean_us = total_time * 1e6 / num_iters + self.report_benchmark( + iters=num_iters, + wall_time=mean_us, + extras={ + "examples_per_sec": + float("{0:.3f}".format(num_iters / total_time)), + "us_per_example": + float("{0:.3f}".format(total_time * 1e6 / num_iters)) + }) + + def benchmark_keras_model_subclassed(self): + model = SubclassedKerasModel() + data = random_ops.random_uniform((10, 10)) + + func = lambda: model(data) + # First call is more expensive (creates variables etc.), discount that. + func() + + # The whole point of this test is to contrast subclassing with + # the functional style of keras model building, so validate that + # the models are equivalent. 
+ assert np.equal(func(), make_keras_model()(data)).all() + + self._run(func, 30000) + + def benchmark_keras_model_functional(self): + model = make_keras_model() + data = random_ops.random_uniform((10, 10)) + func = lambda: model(data) + # Symmetry with benchmark_keras_model_subclassed + func() + assert np.equal(func(), SubclassedKerasModel()(data)).all() + self._run(func, 30000) + + def benchmark_keras_model_sequential(self): + model = make_sequential_keras_model() + data = random_ops.random_uniform((10, 10)) + func = lambda: model(data) + # Symmetry with benchmark_keras_model_functional + func() + assert np.equal(func(), make_keras_model()(data)).all() + self._run(func, 30000) + + def _benchmark_keras_model_fit(self, model, run_eagerly=False): + data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) + labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1) + dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat() + model.compile( + gradient_descent.GradientDescentOptimizer(learning_rate=0.001), + loss="mse", run_eagerly=run_eagerly) + func = lambda: model.fit(dataset, epochs=1, steps_per_epoch=1000, verbose=0) + # First call is more expensive (creates variables etc.), discount that. + model.fit(dataset, epochs=1, steps_per_epoch=1, verbose=0) + + self._run(func, 1) + + def _benchmark_keras_model_evaluate(self, model, run_eagerly=False): + data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) + labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1) + dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat() + model.compile( + gradient_descent.GradientDescentOptimizer(learning_rate=0.001), + loss="mse", run_eagerly=run_eagerly) + func = lambda: model.evaluate(dataset, steps=1000, verbose=0) + # First call is more expensive (creates variables etc.), discount that. 
+ model.evaluate(dataset, steps=1, verbose=0) + + self._run(func, 1) + + def _benchmark_keras_model_predict(self, model, run_eagerly=False): + data = random_ops.random_uniform((10, 10), minval=-1, maxval=1) + dataset = dataset_ops.Dataset.from_tensors(data).repeat() + model.compile( + gradient_descent.GradientDescentOptimizer(learning_rate=0.001), + loss="mse", run_eagerly=run_eagerly) + func = lambda: model.predict(dataset, steps=1000, verbose=0) + # First call is more expensive (creates variables etc.), discount that. + model.predict(dataset, steps=1, verbose=0) + + self._run(func, 1) + + def benchmark_keras_model_subclassed_fit(self): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_subclassed_fit_graph_mode(self): + with context.graph_mode(): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_subclassed_fit_run_model_eagerly(self): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model, run_eagerly=True) + + def benchmark_keras_model_functional_fit(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_functional_fit_graph_mode(self): + with context.graph_mode(): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_functional_fit_graph_mode_with_profiler(self): + profiler.start() + with context.graph_mode(): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + result = profiler.stop() + assert result is not None + + def benchmark_keras_model_functional_fit_run_model_eagerly(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model, run_eagerly=True) + + def 
benchmark_keras_model_functional_fit_run_model_eagerly_with_profiler( + self): + profiler.start() + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model, run_eagerly=True) + result = profiler.stop() + assert result is not None + + def benchmark_keras_model_sequential_fit(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_sequential_fit_graph_mode(self): + with context.graph_mode(): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model) + + def benchmark_keras_model_sequential_fit_run_model_eagerly(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_fit(model, run_eagerly=True) + + def benchmark_keras_model_subclassed_evaluate(self): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model) + + def benchmark_keras_model_subclassed_evaluate_run_model_eagerly(self): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model, run_eagerly=True) + + def benchmark_keras_model_functional_evaluate(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model) + + def benchmark_keras_model_functional_evaluate_run_model_eagerly(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model, run_eagerly=True) + + def benchmark_keras_model_sequential_evaluate(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model) + + def benchmark_keras_model_sequential_evaluate_run_model_eagerly(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_evaluate(model, run_eagerly=True) + + def benchmark_keras_model_subclassed_predict(self): + model = 
SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model) + + def benchmark_keras_model_subclassed_predict_run_model_eagerly(self): + model = SubclassedKerasModel(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model, run_eagerly=True) + + def benchmark_keras_model_functional_predict(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model) + + def benchmark_keras_model_functional_predict_run_model_eagerly(self): + model = make_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model, run_eagerly=True) + + def benchmark_keras_model_sequential_predict(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model) + + def benchmark_keras_model_sequential_predict_run_model_eagerly(self): + model = make_sequential_keras_model(initializer="glorot_uniform") + self._benchmark_keras_model_predict(model, run_eagerly=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/tests/memory_test.py b/tensorflow/python/keras/tests/memory_test.py new file mode 100644 index 00000000000..753820d3295 --- /dev/null +++ b/tensorflow/python/keras/tests/memory_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for memory leaks in eager execution. + +It is possible that this test suite will eventually become flaky due to taking +too long to run (since the tests iterate many times), but for now they are +helpful for finding memory leaks since not all PyObject leaks are found by +introspection (test_util decorators). Please be careful adding new tests here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import keras +from tensorflow.python.eager import backprop +from tensorflow.python.eager import test +from tensorflow.python.eager.memory_tests import memory_test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + + +class SingleLayerNet(keras.Model): + """Simple keras model used to ensure that there are no leaks.""" + + def __init__(self): + super(SingleLayerNet, self).__init__() + self.fc1 = keras.layers.Dense(5) + + def call(self, x): + return self.fc1(x) + + +class MemoryTest(test.TestCase): + + def testMemoryLeakInSimpleModelForwardOnly(self): + if not memory_test_util.memory_profiler_is_available(): + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape(): + net(inputs) + + memory_test_util.assert_no_leak(f) + + def testMemoryLeakInSimpleModelForwardAndBackward(self): + if not memory_test_util.memory_profiler_is_available(): + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape() as tape: + result = net(inputs) + + tape.gradient(result, net.variables) + + del tape + + memory_test_util.assert_no_leak(f) + + +if __name__ == "__main__": + test.main() diff --git 
a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index ad327ac8988..220df9c7f8a 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -187,6 +187,11 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested): return nest._sequence_like(nested, mapped_values) +def get_shapes(tensors): + """Gets shapes from tensors.""" + return nest.map_structure(lambda x: x.shape, tensors) + + # pylint: enable=protected-access diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py index 86d2941b8d3..023766c899d 100644 --- a/tensorflow/python/kernel_tests/argmax_op_test.py +++ b/tensorflow/python/kernel_tests/argmax_op_test.py @@ -68,6 +68,14 @@ class ArgMaxTest(test.TestCase): self._testBothArg(math_ops.argmax, x, 0, x.argmax()) self._testBothArg(math_ops.argmin, x, 0, x.argmin()) + def _testTieBreaking(self, dtype): + x = np.zeros(200, dtype=dtype) + + # Check that argmin and argmax match numpy along the primary axis for + # breaking ties. 
+ self._testBothArg(math_ops.argmax, x, 0, x.argmax()) + self._testBothArg(math_ops.argmin, x, 0, x.argmin()) + def _testDim(self, dtype): shape = (3, 2, 4, 5, 6, 3, 7) x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype) @@ -81,6 +89,7 @@ class ArgMaxTest(test.TestCase): def testFloat(self): self._testBasic(np.float32) + self._testTieBreaking(np.float32) self._testDim(np.float32) def testFloatInt32Output(self): @@ -102,14 +111,17 @@ class ArgMaxTest(test.TestCase): def testDouble(self): self._testBasic(np.float64) + self._testTieBreaking(np.float64) self._testDim(np.float64) def testInt32(self): self._testBasic(np.int32) + self._testTieBreaking(np.int32) self._testDim(np.int32) def testInt64(self): self._testBasic(np.int64) + self._testTieBreaking(np.int64) self._testDim(np.int64) def testEmpty(self): diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc index 4a8e180c154..eb346af896a 100644 --- a/tensorflow/python/lib/core/bfloat16_wrapper.cc +++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/python/lib/core/bfloat16.h" PYBIND11_MODULE(_pywrap_bfloat16, m) { diff --git a/tensorflow/python/lib/core/py_exception_registry_wrapper.cc b/tensorflow/python/lib/core/py_exception_registry_wrapper.cc index 2ae56c3f671..c964d5c64c8 100644 --- a/tensorflow/python/lib/core/py_exception_registry_wrapper.cc +++ b/tensorflow/python/lib/core/py_exception_registry_wrapper.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/python/lib/core/py_exception_registry.h" namespace py = pybind11; diff --git a/tensorflow/python/lib/core/py_func_wrapper.cc b/tensorflow/python/lib/core/py_func_wrapper.cc index 7c3209a3f50..34e6b2df80a 100644 --- a/tensorflow/python/lib/core/py_func_wrapper.cc +++ b/tensorflow/python/lib/core/py_func_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/python/lib/core/py_func.h" namespace py = pybind11; diff --git a/tensorflow/python/lib/core/pybind11_absl.h b/tensorflow/python/lib/core/pybind11_absl.h index db3631dc643..8a05d419654 100644 --- a/tensorflow/python/lib/core/pybind11_absl.h +++ b/tensorflow/python/lib/core/pybind11_absl.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_ABSL_H_ #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_ABSL_H_ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/platform/stringpiece.h" #ifndef ABSL_USES_STD_STRING_VIEW diff --git a/tensorflow/python/lib/core/pybind11_lib.h b/tensorflow/python/lib/core/pybind11_lib.h index 93cae530337..cc2a118d93f 100644 --- a/tensorflow/python/lib/core/pybind11_lib.h +++ b/tensorflow/python/lib/core/pybind11_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_LIB_H_ #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_LIB_H_ diff --git a/tensorflow/python/lib/core/pybind11_proto.h b/tensorflow/python/lib/core/pybind11_proto.h index e99518f2ac4..d69d717d5a0 100644 --- a/tensorflow/python/lib/core/pybind11_proto.h +++ b/tensorflow/python/lib/core/pybind11_proto.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_PROTO_H_ #include "absl/strings/str_cat.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" namespace tensorflow { diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h index b3ef3260792..feb974798de 100644 --- a/tensorflow/python/lib/core/pybind11_status.h +++ b/tensorflow/python/lib/core/pybind11_status.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc index e104881a64d..de806a9c969 100644 --- a/tensorflow/python/lib/io/file_io_wrapper.cc +++ b/tensorflow/python/lib/io/file_io_wrapper.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include #include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/python/lib/io/record_io_wrapper.cc b/tensorflow/python/lib/io/record_io_wrapper.cc index ba71920bf80..d558301eeba 100644 --- a/tensorflow/python/lib/io/record_io_wrapper.cc +++ b/tensorflow/python/lib/io/record_io_wrapper.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/record_reader.h" diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index 1976aa9d3da..2cc35534bea 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/lite/toco/python/toco_python_api.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 3e835663101..9b40ab4cb71 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/python/mlir.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 7c5c7b9a18d..e163cf90eb7 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -199,6 +199,8 @@ def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name def fill(dims, value, name=None): r"""Creates a tensor filled with a scalar value. + See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`. + This operation creates a tensor of shape `dims` and fills it with `value`. For example: @@ -551,7 +553,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin """Returns the shape of a tensor. - See also `tf.size`. + See also `tf.size`, `tf.rank`. This operation returns a 1-D integer tensor representing the shape of `input`. This represents the minimal set of known information at definition time. @@ -775,6 +777,8 @@ def rank(input, name=None): # pylint: disable=redefined-builtin """Returns the rank of a tensor. + See also `tf.shape`. + Returns a 0-D `int32` `Tensor` representing the rank of `input`. For example: @@ -1006,6 +1010,8 @@ def slice(input_, begin, size, name=None): # pylint: disable=redefined-builtin """Extracts a slice from a tensor. + See also `tf.strided_slice`. + This operation extracts a slice of size `size` from a tensor `input_` starting at the location specified by `begin`. The slice `size` is represented as a tensor shape, where `size[i]` is the number of elements of the 'i'th dimension @@ -1068,6 +1074,8 @@ def strided_slice(input_, name=None): """Extracts a strided slice of a tensor (generalized python array indexing). + See also `tf.slice`. 
+ **Instead of calling this op directly most users will want to use the NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.** @@ -1810,6 +1818,8 @@ def sparse_mask(a, mask_indices, name=None): def unique(x, out_idx=dtypes.int32, name=None): """Finds unique elements in a 1-D tensor. + See also `tf.unique_with_counts`. + This operation returns a tensor `y` containing all of the unique elements of `x` sorted in the same order that they occur in `x`. This operation also returns a tensor `idx` the same size as `x` that contains the index @@ -1855,6 +1865,8 @@ unique.__doc__ = gen_array_ops.unique.__doc__ def unique_with_counts(x, out_idx=dtypes.int32, name=None): """Finds unique elements in a 1-D tensor. + See also `tf.unique`. + This operation returns a tensor `y` containing all of the unique elements of `x` sorted in the same order that they occur in `x`. This operation also returns a tensor `idx` the same size as `x` that contains the index @@ -2702,6 +2714,8 @@ def _tag_zeros_tensor(fun): def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. + See also `tf.zeros_like`, `tf.ones`, `tf.fill`, `tf.eye`. + This operation returns a tensor of type `dtype` with shape `shape` and all elements set to zero. @@ -2951,7 +2965,7 @@ def ones_like_impl(tensor, dtype, name, optimize=True): def ones(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to one (1). - See also `tf.ones_like`. + See also `tf.ones_like`, `tf.zeros`, `tf.fill`, `tf.eye`. This operation returns a tensor of type `dtype` with shape `shape` and all elements set to one. @@ -3877,6 +3891,8 @@ def one_hot(indices, name=None): """Returns a one-hot tensor. + See also `tf.fill`, `tf.eye`. + The locations represented by indices in `indices` take value `on_value`, while all other locations take value `off_value`. 
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 63c653b5df1..fe66e8ccdfb 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -443,6 +443,8 @@ def scan(fn, name=None): """scan on the list of tensors unpacked from `elems` on dimension 0. + See also `tf.map_fn`. + The simplest version of `scan` repeatedly applies the callable `fn` to a sequence of elements from first to last. The elements are made of the tensors unpacked from `elems` on dimension 0. The callable fn takes two tensors as diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 52d216cfd71..abca7df19e0 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -194,6 +194,8 @@ def eye(num_rows, name=None): """Construct an identity matrix, or a batch of matrices. + See also `tf.ones`, `tf.zeros`, `tf.fill`, `tf.one_hot`. + ```python # Construct one identity matrix. tf.eye(2) diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index 632bfbc21e7..2c9c678336e 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -53,6 +53,8 @@ def map_fn(fn, fn_output_signature=None): """Transforms `elems` by applying `fn` to each element unstacked on axis 0. + See also `tf.scan`. + `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements; calls `fn` to transform each element; and then stacks the transformed values back together. diff --git a/tensorflow/python/platform/stacktrace_handler_wrapper.cc b/tensorflow/python/platform/stacktrace_handler_wrapper.cc index 7d16f44c8c0..908127b3cca 100644 --- a/tensorflow/python/platform/stacktrace_handler_wrapper.cc +++ b/tensorflow/python/platform/stacktrace_handler_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/platform/stacktrace_handler.h" namespace py = pybind11; diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index ba64e6a71d1..a8b3ea814e9 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h" @@ -93,17 +93,18 @@ class ProfilerSessionWrapper { } private: - tensorflow::profiler::ProfilerOptions GetOptions(const py::dict& opts) { - tensorflow::profiler::ProfilerOptions options; + tensorflow::ProfileOptions GetOptions(const py::dict& opts) { + tensorflow::ProfileOptions options = + tensorflow::ProfilerSession::DefaultOptions(); for (const auto& kw : opts) { std::string key = py::cast(kw.first); if (key == "host_tracer_level") { - options.host_tracer_level = py::cast(kw.second); - VLOG(1) << "host_tracer_level set to " << options.host_tracer_level; + options.set_host_tracer_level(py::cast(kw.second)); + VLOG(1) << "host_tracer_level set to " << options.host_tracer_level(); } else if (key == "python_tracer_level") { - options.enable_python_tracer = py::cast(kw.second) > 0; + options.set_python_tracer_level(py::cast(kw.second)); VLOG(1) << "enable_python_tracer set to " - << options.enable_python_tracer; + << options.python_tracer_level(); } } return options; diff --git a/tensorflow/python/profiler/internal/python_hooks.h 
b/tensorflow/python/profiler/internal/python_hooks.h index 50d890db55b..8a9ce645ca9 100644 --- a/tensorflow/python/profiler/internal/python_hooks.h +++ b/tensorflow/python/profiler/internal/python_hooks.h @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "include/pybind11/cast.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/cast.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" diff --git a/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc b/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc index 8a3364cd859..078ebb0966c 100644 --- a/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc +++ b/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "absl/types/optional.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/scoped_annotation.h" diff --git a/tensorflow/python/profiler/internal/traceme_wrapper.cc b/tensorflow/python/profiler/internal/traceme_wrapper.cc index a2705d54df1..a1b5370836b 100644 --- a/tensorflow/python/profiler/internal/traceme_wrapper.cc +++ b/tensorflow/python/profiler/internal/traceme_wrapper.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include #include "absl/types/optional.h" -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -31,6 +31,12 @@ class TraceMeWrapper { void Enter() { traceme_.emplace(std::move(name_)); } + void SetMetadata(const tensorflow::string& new_metadata) { + if (TF_PREDICT_TRUE(traceme_)) { + traceme_->SetMetadata(new_metadata); + } + } + void Exit() { traceme_.reset(); } static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } @@ -47,5 +53,6 @@ PYBIND11_MODULE(_pywrap_traceme, m) { traceme_class.def(py::init()) .def("Enter", &TraceMeWrapper::Enter) .def("Exit", &TraceMeWrapper::Exit) + .def("SetMetadata", &TraceMeWrapper::SetMetadata) .def_static("IsEnabled", &TraceMeWrapper::IsEnabled); }; diff --git a/tensorflow/python/profiler/trace.py b/tensorflow/python/profiler/trace.py index 88f38bd4968..424bdd6f3fc 100644 --- a/tensorflow/python/profiler/trace.py +++ b/tensorflow/python/profiler/trace.py @@ -24,6 +24,23 @@ from tensorflow.python.profiler.internal import _pywrap_traceme from tensorflow.python.util.tf_export import tf_export +def encode_metadata(metadata): + """Encodes the given metadata to a string. + + Args: + metadata: in key-value pairs. + + Returns: + The encoded string. + """ + if not metadata: + return '' + content = [] + for key, value in six.iteritems(metadata): + content.append('%s=%s'%(key, value)) + return '#' + ','.join(content) + '#' + + @tf_export('profiler.experimental.Trace', v1=[]) class Trace(object): """Context manager that generates a trace event in the profiler. @@ -53,11 +70,29 @@ class Trace(object): Args: name: The name of the trace event. **kwargs: Keyword arguments added to the trace event. + Both the key and value are of types that + can be converted to strings, which will be + interpreted by the profiler according to the + traceme name. 
+ + Example usage: + + ```python + + tf.profiler.experimental.start('logdir') + for step in range(num_steps): + # Creates a trace event for each training step with the + # step number. + with tf.profiler.experimental.Trace("Train", step_num=step): + train_fn() + tf.profiler.experimental.stop() + + ``` + The example above uses the keyword argument "step_num" to specify the + training step being traced. """ if _pywrap_traceme.TraceMe.IsEnabled(): - if kwargs: - name += '#' + ','.join(key + '=' + str(value) - for key, value in six.iteritems(kwargs)) + '#' + name += encode_metadata(kwargs) self._traceme = _pywrap_traceme.TraceMe(name) else: self._traceme = None @@ -65,6 +100,42 @@ class Trace(object): def __enter__(self): if self._traceme: self._traceme.Enter() + return self + + def set_metadata(self, **kwargs): + """Sets metadata in this trace event. + + Args: + **kwargs: metadata in key-value pairs. + + This method enables setting metadata in a trace event after it is + created. + + Example usage: + + ```python + + def call(function): + with tf.profiler.experimental.Trace("call", + function_name=function.name) as tm: + binary, in_cache = jit_compile(function) + tm.set_metadata(in_cache=in_cache) + execute(binary) + + ``` + In this example, we want to trace how much time spent on + calling a function, which includes compilation and execution. + The compilation can be either getting a cached copy of the + binary or actually generating the binary, which is indicated + by the boolean "in_cache" returned by jit_compile(). We need + to use set_metadata() to pass in_cache because we did not know + the in_cache value when the trace was created (and we cannot + create the trace after jit_compile(), because we want + to measure the entire duration of call()). 
+ """ + if self._traceme and kwargs: + additional_metadata = encode_metadata(kwargs) + self._traceme.SetMetadata(additional_metadata) def __exit__(self, exc_type, exc_val, exc_tb): if self._traceme: diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py index 642cf546d20..863e2106f08 100644 --- a/tensorflow/python/pywrap_tensorflow.py +++ b/tensorflow/python/pywrap_tensorflow.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""A wrapper for TensorFlow SWIG-generated bindings.""" +"""A Python wrapper that loads _pywrap_tensorflow_internal.so.""" from __future__ import absolute_import from __future__ import division @@ -24,9 +24,7 @@ import traceback from tensorflow.python.platform import self_check - -# Perform pre-load sanity checks in order to produce a more actionable error -# than we get from an error during SWIG import. +# Perform pre-load sanity checks in order to produce a more actionable error. self_check.preload_check() # pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long @@ -39,10 +37,10 @@ try: except ImportError: _use_dlopen_global_flags = False -# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated -# python library that dynamically loads _pywrap_tensorflow.so. -_can_set_rtld_local = (hasattr(sys, 'getdlopenflags') - and hasattr(sys, 'setdlopenflags')) +# On UNIX-based platforms, pywrap_tensorflow is a python library that +# dynamically loads _pywrap_tensorflow.so. 
+_can_set_rtld_local = ( + hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')) if _can_set_rtld_local: _default_dlopen_flags = sys.getdlopenflags() @@ -55,7 +53,23 @@ try: # override an RTLD_GLOBAL in _default_dlopen_flags). sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL) - from tensorflow.python.pywrap_tensorflow_internal import * + # Python2.7 does not have a ModuleNotFoundError. + try: + ModuleNotFoundError + except NameError: + ModuleNotFoundError = ImportError + + # pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable + try: + from tensorflow.python._pywrap_tensorflow_internal import * + # This try catch logic is because there is no bazel equivalent for py_extension. + # Externally in opensource we must enable exceptions to load the shared object + # by exposing the PyInit symbols with pybind. This error will only be + # caught internally or if someone changes the name of the target _pywrap_tensorflow_internal. + + # This logic is used in other internal projects using py_extension. + except ModuleNotFoundError: + pass if _use_dlopen_global_flags: pywrap_dlopen_global_flags.reset_dlopen_flags() diff --git a/tensorflow/python/pywrap_tensorflow_internal.cc b/tensorflow/python/pywrap_tensorflow_internal.cc new file mode 100644 index 00000000000..8e8f69ec5f8 --- /dev/null +++ b/tensorflow/python/pywrap_tensorflow_internal.cc @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" + +// This logic allows Python to import _pywrap_tensorflow_internal.so by +// creating a PyInit function and exposing it. It is required in opensource +// only. +PYBIND11_MODULE(_pywrap_tensorflow_internal, m){}; diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index e8a9514dd56..9fcffc8ccdf 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -93,16 +93,6 @@ def _concrete_function_callable_with(function, inputs, allow_conversion): flatten_inputs = nest.flatten_up_to(expected_structure, inputs) except (TypeError, ValueError): return False - try: - # Verify that no input elements were dropped during flattening. - repacked = nest.pack_sequence_as(expected_structure, flatten_inputs) - # TODO(b/129422719): Namedtuple subclasses re-created through - # saved_model.load don't compare equal in type to the original in - # assert_same_structure. Fix that and we can take out check_types=False - # here. 
- nest.assert_same_structure(inputs, repacked, check_types=False) - except (TypeError, ValueError): - return False for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): if isinstance(expected, tensor_spec.TensorSpec): diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index 45dfca5617a..42e971d050d 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -178,7 +178,7 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): spec = struct_coder.decode_proto(spec_proto) components = [_get_tensor(component.name) for component in tensor_info.composite_tensor.components] - return spec.from_components(components) + return spec._from_components(components) # pylint: disable=protected-access else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) diff --git a/tensorflow/python/tfcompile_wrapper.cc b/tensorflow/python/tfcompile_wrapper.cc index ac69d326663..c8818309919 100644 --- a/tensorflow/python/tfcompile_wrapper.cc +++ b/tensorflow/python/tfcompile_wrapper.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "include/pybind11/cast.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" -#include "include/pybind11/stl.h" +#include "pybind11/cast.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/python/lib/core/pybind11_lib.h" diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 6891c0411df..f0839cb5721 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -16,11 +16,11 @@ limitations under the License. 
#include #include "Python.h" -#include "include/pybind11/chrono.h" -#include "include/pybind11/complex.h" -#include "include/pybind11/functional.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/chrono.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" @@ -1100,6 +1100,39 @@ PYBIND11_MODULE(_pywrap_tfe, m) { return py::handle(EagerTensorFromHandle(thandle)); }); + m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context, + const py::capsule& device, + const char* device_name, + const py::capsule& device_info) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + if (absl::string_view(device.name()) != "TFE_CustomDevice") { + status->status = tensorflow::errors::InvalidArgument( + "Expected a capsule named 'TFE_CustomDevice' for the `device` " + "argument, got ", + absl::string_view(device.name())); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + } + if (absl::string_view(device_info.name()) != + "TFE_CustomDevice_DeviceInfo") { + status->status = tensorflow::errors::InvalidArgument( + "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for " + "the `device_info` argument, got ", + absl::string_view(device_info.name())); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + } + // TFE_RegisterCustomDevice takes ownership + PyCapsule_SetDestructor(device_info.ptr(), nullptr); + TFE_RegisterCustomDevice( + tensorflow::InputTFE_Context(context), + *reinterpret_cast( + PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")), + device_name, + PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"), + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + // C API Enum py::enum_( diff --git 
a/tensorflow/python/training/quantize_training_wrapper.cc b/tensorflow/python/training/quantize_training_wrapper.cc index f4173553ed6..27abc27c4a2 100644 --- a/tensorflow/python/training/quantize_training_wrapper.cc +++ b/tensorflow/python/training/quantize_training_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/graph/quantize_training.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_status.h" diff --git a/tensorflow/python/util/kernel_registry_wrapper.cc b/tensorflow/python/util/kernel_registry_wrapper.cc index 625b6598e10..3f607c13b12 100644 --- a/tensorflow/python/util/kernel_registry_wrapper.cc +++ b/tensorflow/python/util/kernel_registry_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/python/util/kernel_registry.h" namespace py = pybind11; diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index afefa502593..517030193de 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -158,8 +158,15 @@ def _sequence_like(instance, args): instance_type = type(instance) tf_logging.log_first_n( tf_logging.WARN, "Mapping types may not work well with tf.nest. 
Prefer" - "using MutableMapping for {}".format(instance_type), 1) - return instance_type((key, result[key]) for key in instance) + " using MutableMapping for {}".format(instance_type), 1) + try: + return instance_type((key, result[key]) for key in instance) + except TypeError as err: + raise TypeError("Error creating an object of type {} like {}. Note that " + "it must accept a single positional argument " + "representing an iterable of key-value pairs, in " + "addition to self. Cause: {}".format( + type(instance), instance, err)) elif _is_mapping_view(instance): # We can't directly construct mapping views, so we create a list instead return list(args) diff --git a/tensorflow/python/util/port_wrapper.cc b/tensorflow/python/util/port_wrapper.cc index c1b102f328b..8f0ab98b778 100644 --- a/tensorflow/python/util/port_wrapper.cc +++ b/tensorflow/python/util/port_wrapper.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/util/port.h" PYBIND11_MODULE(_pywrap_util_port, m) { diff --git a/tensorflow/python/util/py_checkpoint_reader_wrapper.cc b/tensorflow/python/util/py_checkpoint_reader_wrapper.cc index a7076f6ee29..d034505c66d 100644 --- a/tensorflow/python/util/py_checkpoint_reader_wrapper.cc +++ b/tensorflow/python/util/py_checkpoint_reader_wrapper.cc @@ -18,11 +18,11 @@ limitations under the License. 
#include "numpy/arrayobject.h" #include "numpy/ufuncobject.h" -#include "include/pybind11/chrono.h" -#include "include/pybind11/complex.h" -#include "include/pybind11/functional.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl.h" +#include "pybind11/chrono.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" #include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/python/util/stat_summarizer_wrapper.cc b/tensorflow/python/util/stat_summarizer_wrapper.cc index f46ddc518e0..886c8cb50de 100644 --- a/tensorflow/python/util/stat_summarizer_wrapper.cc +++ b/tensorflow/python/util/stat_summarizer_wrapper.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/util/stat_summarizer.h" diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 4e0f3dd4dc9..5c18c87ad32 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "include/pybind11/pybind11.h" -#include "include/pybind11/stl_bind.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl_bind.h" struct FrameSummary; // Forward declaration. diff --git a/tensorflow/python/util/tfprof_wrapper.cc b/tensorflow/python/util/tfprof_wrapper.cc index 0d7b51840bb..55fc4f0e87d 100644 --- a/tensorflow/python/util/tfprof_wrapper.cc +++ b/tensorflow/python/util/tfprof_wrapper.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/profiler/internal/print_model_analysis.h" diff --git a/tensorflow/python/util/transform_graph_wrapper.cc b/tensorflow/python/util/transform_graph_wrapper.cc index 1859f0a2b5b..24d67c7b6f7 100644 --- a/tensorflow/python/util/transform_graph_wrapper.cc +++ b/tensorflow/python/util/transform_graph_wrapper.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "include/pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/pybind11_status.h" diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 6da3fdbf945..1d0dd695d74 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -730,9 +730,9 @@ bool AssertSameStructureHelper( // We treat two different namedtuples with identical name and fields // as having the same type. - const PyObject* o1_tuple = IsNamedtuple(o1, true); + const PyObject* o1_tuple = IsNamedtuple(o1, false); if (o1_tuple == nullptr) return false; - const PyObject* o2_tuple = IsNamedtuple(o2, true); + const PyObject* o2_tuple = IsNamedtuple(o2, false); if (o2_tuple == nullptr) { Py_DECREF(o1_tuple); return false; diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc index 50ea922ef52..244ab3e5fc2 100644 --- a/tensorflow/python/util/util_wrapper.cc +++ b/tensorflow/python/util/util_wrapper.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "include/pybind11/pybind11.h" -#include "include/pybind11/pytypes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/util/util.h" diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 5e18203844e..dd432ad531a 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1989,6 +1989,148 @@ _append_init_to_versionscript = rule( implementation = _append_init_to_versionscript_impl, ) +# This macro should only be used for pywrap_tensorflow_internal.so. +# It was copied and refined from the original tf_py_wrap_cc rule. +# buildozer: disable=function-docstring-args +def pywrap_tensorflow_macro( + name, + srcs = [], + deps = [], + copts = [], + version_script = None, + **kwargs): + """Builds the pywrap_tensorflow_internal shared object.""" + module_name = name.split("/")[-1] + + # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so + # and use that as the name for the rule producing the .so file. + cc_library_base = "/".join(name.split("/")[:-1] + ["_" + module_name]) + + # TODO(b/137885063): tf_cc_shared_object needs to be cleaned up; we really + # shouldn't be passing a name qualified with .so here. + cc_library_name = cc_library_base + ".so" + cc_library_pyd_name = "/".join( + name.split("/")[:-1] + ["_" + module_name + ".pyd"], + ) + + # We need pybind11 to export the shared object PyInit symbol only in OSS. 
+ extra_deps = ["@pybind11"] + + if not version_script: + version_script = select({ + "@local_config_cuda//cuda:darwin": clean_dep("//tensorflow:tf_exported_symbols.lds"), + "//conditions:default": clean_dep("//tensorflow:tf_version_script.lds"), + }) + vscriptname = name + "_versionscript" + _append_init_to_versionscript( + name = vscriptname, + is_version_script = select({ + "@local_config_cuda//cuda:darwin": False, + "//conditions:default": True, + }), + module_name = module_name, + template_file = version_script, + ) + extra_linkopts = select({ + "@local_config_cuda//cuda:darwin": [ + "-Wl,-exported_symbols_list,$(location %s.lds)" % vscriptname, + ], + clean_dep("//tensorflow:windows"): [], + "//conditions:default": [ + "-Wl,--version-script", + "$(location %s.lds)" % vscriptname, + ], + }) + extra_deps += select({ + "@local_config_cuda//cuda:darwin": [ + "%s.lds" % vscriptname, + ], + clean_dep("//tensorflow:windows"): [], + "//conditions:default": [ + "%s.lds" % vscriptname, + ], + }) + + # Due to b/149224972 we have to add libtensorflow_framework.so + # as a dependency so the linker doesn't try and optimize and + # remove it from pywrap_tensorflow_internal.so + # Issue: https://github.com/tensorflow/tensorflow/issues/34117 + # Fix: https://github.com/tensorflow/tensorflow/commit/5caa9e83798cb510c9b49acee8a64efdb746207c + extra_deps += if_static( + extra_deps = [], + otherwise = [ + clean_dep("//tensorflow:libtensorflow_framework_import_lib"), + ], + ) + + tf_cc_shared_object( + name = cc_library_name, + srcs = srcs, + # framework_so is no longer needed as libtf.so is included via the extra_deps. + framework_so = [], + copts = copts + if_not_windows([ + "-Wno-self-assign", + "-Wno-sign-compare", + "-Wno-write-strings", + ]), + linkopts = extra_linkopts, + linkstatic = 1, + deps = deps + extra_deps, + **kwargs + ) + + # When a non-versioned .so is added as a 'src' to a bazel target, it uses + # -l%(so_name) instead of -l:%(so_file) during linking. 
When -l%(so_name) + # is passed to ld, it will look for an associated file with the schema + # lib%(so_name).so. Since pywrap_tensorflow is not explicitly versioned + # and is not prefixed with lib_, we add a rule for the creation of an .so + # file with the canonical lib schema (e.g. libNAME.so), so that + # -l%(so_name) is resolved during linking. + # + # See: https://github.com/bazelbuild/bazel/blob/7a6808260a733d50983c1adf0cf5a7493472267f/src/main/java/com/google/devtools/build/lib/rules/cpp/LibrariesToLinkCollector.java#L319 + for pattern in SHARED_LIBRARY_NAME_PATTERNS: + name_os = pattern % (cc_library_base, "") + native.genrule( + name = name_os + "_rule", + srcs = [":" + cc_library_name], + outs = [name_os], + cmd = "cp $< $@", + ) + + native.genrule( + name = "gen_" + cc_library_pyd_name, + srcs = [":" + cc_library_name], + outs = [cc_library_pyd_name], + cmd = "cp $< $@", + ) + + # TODO(amitpatankar): Remove this py_library reference and + # move the dependencies to pywrap_tensorflow. This can + # eliminate one layer of Python import redundancy. We would + # have to change all pywrap_tensorflow imports to + # pywrap_tensorflow_internal. + + # Bazel requires an empty .py file for pywrap_tensorflow_internal.py. + empty_py_file = [name + ".py"] + native.genrule( + name = "empty_py_file_rule", + outs = empty_py_file, + cmd = "touch $@", + ) + + native.py_library( + name = name, + srcs = [":" + name + ".py"], + srcs_version = "PY2AND3", + data = select({ + clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name], + "//conditions:default": [":" + cc_library_name], + }), + ) + +# DO NOT USE! We are in the process of deprecating this. If you use +# this rule within third_party/tensorflow you will be rolled back. 
b/153452665 +# buildozer: enable=function-docstring-args def tf_py_wrap_cc( name, srcs = [], @@ -2672,6 +2814,7 @@ def pybind_extension( deprecation = deprecation, restricted_to = restricted_to, compatible_with = compatible_with, + testonly = testonly, ) native.py_library( name = name, @@ -2699,7 +2842,8 @@ def tf_python_pybind_extension( hdrs = [], deps = [], defines = [], - visibility = None): + visibility = None, + testonly = None): """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD. Please do not use it anywhere else as it may behave unexpectedly. b/146445820 @@ -2718,6 +2862,7 @@ def tf_python_pybind_extension( defines = defines, visibility = visibility, link_in_framework = True, + testonly = testonly, ) def tf_pybind_cc_library_wrapper(name, deps, visibility = None): diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds index 8bbd4199f82..734b09005ae 100644 --- a/tensorflow/tf_exported_symbols.lds +++ b/tensorflow/tf_exported_symbols.lds @@ -8,3 +8,4 @@ *nsync_* *stream_executor* *xla* +*PyInit_* diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds index 303ba98b9a3..a32da327aaa 100644 --- a/tensorflow/tf_version_script.lds +++ b/tensorflow/tf_version_script.lds @@ -10,6 +10,7 @@ tensorflow { *nsync_*; *stream_executor*; *xla*; + *PyInit_*; local: *; }; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index 414f682473c..081d0639667 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -143,7 +143,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { 
name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index fb929010980..59fec51d8e6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -148,7 +148,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt index 2d64a7bb9e0..f9333d139e3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt @@ -144,7 +144,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 8f35f4b877f..3b12b4e8055 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt index 0b94554a7bb..578fbf03f77 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 4f17a33773c..5cf3162ce42 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -144,7 +144,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt index 7d5a096e0f7..3ba96bab6fe 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt index 60f20efb9d5..3f59d9987a5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt index 16ca6f428b4..acc72ebf939 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt index fa0ac16192d..839d57e4c94 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt index 2cfd5b6c11d..1c22721666b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt index b86d5180031..cf883e74088 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } 
member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt index 60b4777624f..70800bccf8c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt index 46c6a028077..11f70522f1a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 7e3dd70fdb2..ff311806b47 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 6326191280e..dc3cc76d9e1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt index 701bbcca0e0..6fdcb8c9000 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt index 955535104b2..a5d912c9b8e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } 
member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt index cb98e5d728d..7471b7306d3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 8fdc3bf6cbd..323c0d51988 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt index d0109d12cc8..b5de4b0e7a0 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt index 13dfd36608d..16143b3b20e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt index bf3900eeb3e..2bea88de2fd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index 3b163115cb4..444220d4e06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -211,7 +211,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt new file mode 100644 index 00000000000..22de9fb79ff --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt @@ -0,0 +1,220 @@ +path: "tensorflow.keras.layers.Conv1DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + 
mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', 
\'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + 
name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt index 7748f763576..b45954626ba 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index 7834932b5bb..da6bfec7499 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt index 2d03874d6b1..b66d4fc4d3c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index a2998f59114..4e9ce619361 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt index e7974563f59..fedb39dbd21 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: 
"add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt new file mode 100644 index 00000000000..28357ffa0f6 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt @@ -0,0 +1,220 @@ +path: "tensorflow.keras.layers.Convolution1DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + 
mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: 
"call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + 
member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt index f4906272693..6d97faacece 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index 0f6aecd876e..830caf7f693 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt index 7e60cdfdce3..df115f618c7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 3f750a6200b..69f71b6a3ff 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt index 3071323b7d1..f58aa3e1baa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt index 98354dbb0d8..44b66135732 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt index 5ce76e3974e..63591c0e984 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt index b4ec7544d17..b5e96804759 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt index b52ef058c73..c11ee1eea4c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt index 971084500f9..a2a805817fd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt index 8550078d41e..f816c00d9d5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt @@ 
-122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt index 7281f900be9..31b101ce81b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 955f38d90fa..46138e74b4b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt index 741fd7c85cf..4f45a085317 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt index af72a0eeabe..869d8d4817b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt index 01749e8743e..33a95bd2312 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt index ec70d321c17..35c25eab279 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt index 2b24e3d5f63..955ec7a0a49 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt index b94ccb3ad80..02dc67771b7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt index eeed5a4f45b..939dde608aa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt @@ -194,7 +194,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 8a62b5e06b8..b966a1fa48a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt index ca5705d9031..bcadf04ab46 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } 
member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index 5868fa64fa0..93f9f085028 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index f04f63ed8ec..c1988faf3d7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index d3ad8438bc4..516e93110c5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } 
member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 5f4f350fda5..545af759275 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index a593841db92..13fc0dade36 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index 190ae3e6e34..5c6515f166d 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d24d9cd6d8a..27bde045cbd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index dedc983a0e8..21ee43eb016 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 391feaace06..14fac4a4edd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 89cbb5da560..0cc18b9a462 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 4a327d47033..cb26f965881 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 60c7502868b..aef01152cfe 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt index 489029959ae..6366a29f0b9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 652b740b771..a15b042d96e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt index acaebc5c8bb..975df5f3b1e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -194,7 +194,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt index d64a6880b4c..14b809390eb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt index 8b91bcbbc12..f1adf9b2178 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt index 57bda42bd84..2dcb55a3331 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt @@ -120,7 +120,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 7d9a7e112e8..85b4a635d9e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt index e940d9e37e2..bb4c63d4289 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt index 6a57f886512..8068baf2931 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt index 62de174e026..775cc8f4458 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, 
defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt index 26626090507..8fd7d059937 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt index f94b0ab7811..aadaea15b7b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt index b689aaa278b..ea1c60e48d3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt index c38682001ce..b9f09656973 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 51659f8a081..ade1e839676 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 35346390ce5..2d129d415da 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { 
} member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt index 0c803b3c689..b4adbbcbea2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt index ccc586965d3..12d2cc690b8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt index 4d5456281b0..5e5d3992927 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt index d3972cf2ec1..733fb63d1fb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt index c3d8882fc02..3e2d70a5a0a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt index 00a65cd1e92..3018929154e 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt index a38bd7d5412..7af41433d28 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt index bbcdf228382..52eb2c247cf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt index 6664452d99e..08658b26be3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 718247fb031..9bab5a78338 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 6d24ef70876..2bcc06f9330 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, 
defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 443e5f043c1..823e28a8bb9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index a3d8f2a29b1..c27047ecd71 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index a7c83abef1e..417e79df321 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { 
name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 6b31632148e..e6e12106c6c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -182,7 +182,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt index c84b56e466b..8b435bd2b41 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index b45678622cb..d5fbff4d5c6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index fff0df98d0c..287e0167076 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 34396757a54..78ab93ae395 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 35910318eb0..27afe1a56c6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt index 2408c8676ff..b060c3169fd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index b264adfbd87..272fd09afc6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, 
defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt index 3236d720cf4..95274944084 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt index bd15fd0202d..8c8f4f287bd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt index dd58cf38f3f..c56ea3122ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt index e738ca01ce8..80c647c9fc1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt index dbc7c545d23..63423b9ee0c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt index f423bf9a270..e5a31b88df9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -121,7 +121,7 @@ 
tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt index f871573794b..b170d030fe8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 7b3859cf838..6010e155661 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt index 02bb9b6e6fa..4a846b138a9 
100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt index 7c907cac2ab..9feb216577a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt index e1e67575de6..d84d810bdd0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt index 009e4781cc1..c8cc33fea5d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt index b652c8d099f..2c6b4bc0c9c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt index 1a1aaea1964..782a7d56892 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt index 1c4d6639f8f..769fbd0b5ac 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index 8aeee741de8..f539ee33804 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ 
-121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index 1b0daace7bf..57b20ce4031 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt index 704fb827c45..d4b19e22028 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index 3fb7e6856c8..57eb2c9c175 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index a4c6a0b1510..7816930fd5c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt index 398b93ccd72..05f110140bf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt index 773fe692329..ab5067a23ca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt index 847cc814e0f..10dbfb56078 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt @@ -72,6 +72,10 @@ tf_module { name: "Conv1D" mtype: "" } + member { + name: "Conv1DTranspose" + mtype: "" + } member { name: "Conv2D" mtype: "" @@ -96,6 +100,10 @@ tf_module { name: "Convolution1D" mtype: "" } + member { + name: "Convolution1DTranspose" + mtype: "" + } member { name: "Convolution2D" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt index 7c860b922bb..64ccf7c98ac 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt index 768d9b7f6a3..d211a16597e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt index 9598d148015..58103637fe3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: 
"add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt index 8a8917d3e4d..4f748914101 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt index 67fd36960d3..42e57f86769 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt index 5ce32740800..6ef136de517 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', 
\'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt index 9cb128b22e4..f3379748fb9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt index 1bb29ee6c6a..9367edcb228 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt index 3079a731d7d..820d2ed1e7c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt index 146e10f944e..de23747dcef 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt index 3c9fce69676..edae2d27448 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt index 
1eb59f483de..171ade560a9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt index cab6edabce1..8713d9aa427 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt index 4c697325134..6c5541e71d9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt index f0796ab9d22..45b94842278 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt index b99ecdbea74..90733200606 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt index bc117988392..5997066e37a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt index 179e6384e47..b6a0e00ffa0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt index df65e2f5ad6..b698ab5ff65 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt index 6632e586a6d..01a3d3f6e07 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt 
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt @@ -130,7 +130,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt index 428b2e336d8..c47a1bc749c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt index e0f0b32fb6d..fe72d0ad1d6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt index 7adbd07bb38..befbf09ed11 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt index b6ff6bf42a6..3f001a9d4e2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt index 97e368aedd0..e4d66868b1b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt index a5f143ca664..daad023ef66 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt index 19c29897a46..40329ff21be 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt index 07cdff6c8ac..a05adf6070a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt index 17e3639509b..466ef391017 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt index a0055d4ecee..12a1f5daa14 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt index 5fe7617b2ec..86ae28fb876 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt index e07b0b1c64f..6a52c10edbb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt index 99eb90769bd..08a9118eac0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt index cfc64c0c28e..810f4e61806 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt index d1d771e3992..9ee1af61e34 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt index 1c4b1eacc74..41dcc25644f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', 
\'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt index 8ee532ffd36..3726bff3850 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt index 4b22ce73cea..1ca4fc5c21b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index b77893eaeda..070fd981fc0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -143,7 +143,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 939657fd748..51f9bce2453 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -148,7 +148,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt index f99dad33a6d..5f89addb628 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt index 
eb688a9c676..056e98d47d0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt index 20ee5a52952..63eaa837e93 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index e910faef781..f7199a19a28 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt index 5ff802f6e48..9e67ceb7f45 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt index cd98e2773ba..52724953edc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -133,7 +133,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt index a65847cfda5..da144936c76 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: 
"add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt index 3ea91c6fb66..8672e9bd666 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -133,7 +133,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt index d39b663299e..13144714b0b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt index fec6b718cd2..3fb0aa15efa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } 
member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt index edc81f739e6..8ad21ade6a4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt index 3cc78accc78..e5b86e2be32 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt index 2b65efadcef..cc325d01456 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: 
"add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt index c05d1d10329..8a8199b99d4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt index a0e5e6fb050..ebdeed19c77 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt index 788c181a96f..a0a342c5b86 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, 
keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt index 920dcf0f747..dc81593ffde 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt @@ -133,7 +133,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt index 33bb2ee8785..d4ab15d3217 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt @@ -133,7 +133,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt index 34cc5a20beb..14be46801fa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: 
"args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt index f17c071254e..6b8e613e6d0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index d337b185c46..2c95ad00100 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 7269795bde5..2eae2ba5168 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 1368c1cb603..86adb8695ca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -141,7 +141,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index 6d490621aa9..05491f86058 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -145,7 +145,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index a9669ff59a2..17813a8de8b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 54e517ac974..21a0b8ab4eb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -140,7 +140,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index db31de9d754..d859a4666a0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -139,7 +139,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', 
\'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 2286a66efd8..29a592022b2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -138,7 +138,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 5570bf7af98..96f625f7f55 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -141,7 +141,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index d9414c31e7d..acc6334055f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -58,7 +58,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', 
\'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 28efdb6e855..1c5ab59020e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index c9553efb58c..18c77bf4289 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -59,7 +59,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 16a878144ae..6ebe24206ab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index d1d2db041e0..41865c9700f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 18a6b8cbd1b..ae905aa1fea 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 0cf3d94ba68..3f274660402 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index 414f682473c..081d0639667 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt 
@@ -143,7 +143,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index fb929010980..59fec51d8e6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -148,7 +148,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt index 2d64a7bb9e0..f9333d139e3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt @@ -144,7 +144,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 8f35f4b877f..3b12b4e8055 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt index 0b94554a7bb..578fbf03f77 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 4f17a33773c..5cf3162ce42 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -144,7 +144,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: 
"add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt index 7d5a096e0f7..3ba96bab6fe 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt index 60f20efb9d5..3f59d9987a5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt index 16ca6f428b4..acc72ebf939 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt index fa0ac16192d..839d57e4c94 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt index 2cfd5b6c11d..1c22721666b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt index b86d5180031..cf883e74088 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method 
{ name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt index 60b4777624f..70800bccf8c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt index 46c6a028077..11f70522f1a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 7e3dd70fdb2..ff311806b47 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 6326191280e..dc3cc76d9e1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt index 701bbcca0e0..6fdcb8c9000 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt index 955535104b2..a5d912c9b8e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt index cb98e5d728d..7471b7306d3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 8fdc3bf6cbd..323c0d51988 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt index 74a689f32df..71ca168a55c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt index 13dfd36608d..16143b3b20e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt index bf3900eeb3e..2bea88de2fd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', 
\'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index 3b163115cb4..444220d4e06 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -211,7 +211,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt new file mode 100644 index 00000000000..22de9fb79ff --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt @@ -0,0 +1,220 @@ +path: "tensorflow.keras.layers.Conv1DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { 
+ name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', 
\'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: 
"args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt index 7748f763576..b45954626ba 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index 7834932b5bb..da6bfec7499 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt index 2d03874d6b1..b66d4fc4d3c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index a2998f59114..4e9ce619361 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt index 
e7974563f59..fedb39dbd21 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt new file mode 100644 index 00000000000..28357ffa0f6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt @@ -0,0 +1,220 @@ +path: "tensorflow.keras.layers.Convolution1DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + 
member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: 
"args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: 
"get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt index f4906272693..6d97faacece 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index 0f6aecd876e..830caf7f693 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', 
\'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt index 7e60cdfdce3..df115f618c7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 3f750a6200b..69f71b6a3ff 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt index 3071323b7d1..f58aa3e1baa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: 
"add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt index 98354dbb0d8..44b66135732 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt index 5ce76e3974e..63591c0e984 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt index b4ec7544d17..b5e96804759 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt index f010cd09dfb..d035db30248 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt index 7281f900be9..31b101ce81b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 
955f38d90fa..46138e74b4b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt index 741fd7c85cf..4f45a085317 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt index af72a0eeabe..869d8d4817b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt index 01749e8743e..33a95bd2312 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt index ec70d321c17..35c25eab279 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt index 2b24e3d5f63..955ec7a0a49 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt index ce9eb821fc3..0bbca8b0628 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt index 943ef2e5db4..8365c652b9d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt @@ -196,7 +196,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 8a62b5e06b8..b966a1fa48a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt index ca5705d9031..bcadf04ab46 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index 5868fa64fa0..93f9f085028 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index f04f63ed8ec..c1988faf3d7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -122,7 +122,7 @@ 
tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index d3ad8438bc4..516e93110c5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 5f4f350fda5..545af759275 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 
a593841db92..13fc0dade36 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index 190ae3e6e34..5c6515f166d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d24d9cd6d8a..27bde045cbd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" 
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index dedc983a0e8..21ee43eb016 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 391feaace06..14fac4a4edd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 89cbb5da560..0cc18b9a462 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 4a327d47033..cb26f965881 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 60c7502868b..aef01152cfe 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt index 489029959ae..6366a29f0b9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 9a23855da9a..a0deeb6dbd3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt index f2c1b7e8f94..e000180ee73 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -196,7 +196,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt index d64a6880b4c..14b809390eb 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt index 8b91bcbbc12..f1adf9b2178 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt index 57bda42bd84..2dcb55a3331 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt @@ -120,7 +120,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 7d9a7e112e8..85b4a635d9e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt index e940d9e37e2..bb4c63d4289 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt index 6a57f886512..8068baf2931 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt index 62de174e026..775cc8f4458 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt index 26626090507..8fd7d059937 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt index f94b0ab7811..aadaea15b7b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt index b689aaa278b..ea1c60e48d3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt index c38682001ce..b9f09656973 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 51659f8a081..ade1e839676 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } 
member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 35346390ce5..2d129d415da 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt index 0c803b3c689..b4adbbcbea2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt index ccc586965d3..12d2cc690b8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt index 4d5456281b0..5e5d3992927 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt index d3972cf2ec1..733fb63d1fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt index c3d8882fc02..3e2d70a5a0a 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt index 00a65cd1e92..3018929154e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt index a38bd7d5412..7af41433d28 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt index bbcdf228382..52eb2c247cf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt index 6664452d99e..08658b26be3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 718247fb031..9bab5a78338 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } 
member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 6d24ef70876..2bcc06f9330 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 443e5f043c1..823e28a8bb9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index a3d8f2a29b1..c27047ecd71 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - 
argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index a7c83abef1e..417e79df321 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 6b31632148e..e6e12106c6c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -182,7 +182,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt index c84b56e466b..8b435bd2b41 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index b45678622cb..d5fbff4d5c6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index fff0df98d0c..287e0167076 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 34396757a54..78ab93ae395 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 35910318eb0..27afe1a56c6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt index 2408c8676ff..b060c3169fd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, 
defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index b264adfbd87..272fd09afc6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt index 3236d720cf4..95274944084 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt index bd15fd0202d..8c8f4f287bd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', 
\'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt index dd58cf38f3f..c56ea3122ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt index e738ca01ce8..80c647c9fc1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt index dbc7c545d23..63423b9ee0c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt 
@@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt index f423bf9a270..e5a31b88df9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt index f871573794b..b170d030fe8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 7b3859cf838..6010e155661 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt index bef86001afc..e3a91f6791b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt index 02bb9b6e6fa..4a846b138a9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', 
\'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt index 70a58143e59..ea54293bca1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt @@ -127,7 +127,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt index e1e67575de6..d84d810bdd0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt index 
009e4781cc1..c8cc33fea5d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt index b652c8d099f..2c6b4bc0c9c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt index 1a1aaea1964..782a7d56892 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt index 1c4d6639f8f..769fbd0b5ac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index 8aeee741de8..f539ee33804 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index 1b0daace7bf..57b20ce4031 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt index 704fb827c45..d4b19e22028 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index 3fb7e6856c8..57eb2c9c175 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ 
-121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index a4c6a0b1510..7816930fd5c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt index 398b93ccd72..05f110140bf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt index 07f235d921d..025b1a013cb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt @@ -127,7 +127,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt index 5574cc9ca59..3706919341d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt @@ -72,6 +72,10 @@ tf_module { name: "Conv1D" mtype: "" } + member { + name: "Conv1DTranspose" + mtype: "" + } member { name: "Conv2D" mtype: "" @@ -96,6 +100,10 @@ tf_module { name: "Convolution1D" mtype: "" } + member { + name: "Convolution1DTranspose" + mtype: "" + } member { name: "Convolution2D" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt index 7c860b922bb..64ccf7c98ac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: 
"args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt index 768d9b7f6a3..d211a16597e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt index 9598d148015..58103637fe3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt index 8a8917d3e4d..4f748914101 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: 
"add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt index 67fd36960d3..42e57f86769 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt index 5ce32740800..6ef136de517 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt index 9cb128b22e4..f3379748fb9 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt index 1bb29ee6c6a..9367edcb228 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt index 3079a731d7d..820d2ed1e7c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt index 146e10f944e..de23747dcef 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt index 3c9fce69676..edae2d27448 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt index 1eb59f483de..171ade560a9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', 
\'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt index cab6edabce1..8713d9aa427 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt index 4c697325134..6c5541e71d9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt index f0796ab9d22..45b94842278 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt @@ 
-125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt index b99ecdbea74..90733200606 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt index bc117988392..5997066e37a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt index 179e6384e47..b6a0e00ffa0 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt index df65e2f5ad6..b698ab5ff65 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt index 6632e586a6d..01a3d3f6e07 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt @@ -130,7 +130,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff 
--git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt index 428b2e336d8..c47a1bc749c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt index e0f0b32fb6d..fe72d0ad1d6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt index 7adbd07bb38..befbf09ed11 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, 
defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt index b6ff6bf42a6..3f001a9d4e2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt index 97e368aedd0..e4d66868b1b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt index a5f143ca664..daad023ef66 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', 
\'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt index 19c29897a46..40329ff21be 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt index 07cdff6c8ac..a05adf6070a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt index 17e3639509b..466ef391017 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt index a0055d4ecee..12a1f5daa14 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt index 5fe7617b2ec..86ae28fb876 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt index e07b0b1c64f..6a52c10edbb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt index 99eb90769bd..08a9118eac0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt index cfc64c0c28e..810f4e61806 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: 
"args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt index d1d771e3992..9ee1af61e34 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt index 1c4b1eacc74..41dcc25644f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt index 8ee532ffd36..3726bff3850 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt index 4b22ce73cea..1ca4fc5c21b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index b77893eaeda..070fd981fc0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -143,7 +143,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 939657fd748..51f9bce2453 
100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -148,7 +148,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt index 81bacd3d5c1..66ab6b5ca67 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt index c7b8afc296e..021fe16877a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt index 
a1a31a499e4..c87a817fd32 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt index c17e036672e..03480e9b8c9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt index e323260dfd7..fb635300604 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt index 14ea539ac3d..155ab36818a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt index 4154ebd8e1b..f5fb6f79099 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt index 51a85c4ab60..8c8e76c2f3e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: 
"args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt index 2d8a3034399..1427375ea4d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt index 65cb8f74d56..06095a48cab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt index 5c32f3f766e..623f74d34fc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt index 6da06925202..6bf1092aa68 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt index 63c4092edf6..7812d8715b5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt index a768dafb019..0719052ccab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: 
"add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt index 3ba140d3a1c..c644e71f8a1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt index 10dff360790..102f4715c7e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt index 206be798e6a..c4c5608c4c0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt 
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt index 5d7053c3faf..6847db69ead 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt index 9aede80e88f..fe8c55b9465 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt index e8235673c7e..e43e5265258 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt @@ -130,7 +130,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt index 52c3e918010..943b3029985 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt index f980068b19b..2f9d17d37e9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt @@ -121,7 +121,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt index e33eb6bde0a..b583863298c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt index cd29bc64128..ddd3954ed18 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt index 7fdfcb05897..9d74d0fbc70 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt index 62ea36e01a2..7ad458bc2db 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt index 6b517e93185..c6f4672ff20 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt @@ -122,7 +122,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt index 79ae8ae4e0f..97ba4531f73 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt index 21b837f1365..b29b4cdfa54 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt index 2cbb97cfeee..408743cb054 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt index 2e1d712e596..ebb08e7efc1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt @@ 
-125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt index 515e2468d02..4c8925cd1c9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt index e88b245b097..c714447e72a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt 
index 1909076bb44..ca6a07356ca 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt index 64ace36282a..b5985644e73 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt index 9edafe0753d..99b712beb94 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt @@ -125,7 +125,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt index 29904ce80d2..2c31194e622 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt index a65949189c5..d2567e9050d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt @@ -123,7 +123,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt index e94252a2f6d..7b08c50d6f9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], 
varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt index 9682fd0a29a..7fc672d250e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt @@ -136,7 +136,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt index 0cd0cebfa1d..f7cbc16a57d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt @@ -132,7 +132,7 @@ tf_class { } member_method { name: "add_metric" - argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "add_update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-trace.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-trace.pbtxt index cd7e6631047..c2d22b28677 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-trace.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-trace.pbtxt @@ -6,4 +6,8 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'name\'], 
varargs=None, keywords=kwargs, defaults=None" } + member_method { + name: "set_metadata" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } } diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh deleted file mode 100755 index 49bb51465a4..00000000000 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ /dev/null @@ -1,669 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# -# Usage: -# ci_parameterized_build.sh -# -# The script obeys the following required environment variables: -# TF_BUILD_CONTAINER_TYPE: (CPU | GPU | ANDROID | ANDROID_FULL) -# TF_BUILD_PYTHON_VERSION: (PYTHON2 | PYTHON3 | PYTHON3.5) -# TF_BUILD_IS_PIP: (NO_PIP | PIP | BOTH) -# -# The below environment variable is required, but will be deprecated together -# with TF_BUILD_MAVX and both will be replaced by TF_BUILD_OPTIONS. -# TF_BUILD_IS_OPT: (NO_OPT | OPT) -# -# Note: -# 1) Certain combinations of parameter values are regarded -# as invalid and will cause the script to exit with code 0. 
For example: -# NO_OPT & PIP (PIP builds should always use OPT) -# ANDROID & PIP (Android and PIP builds are mutually exclusive) -# -# 2) TF_BUILD_PYTHON_VERSION is set to PYTHON3, the build will use the version -# pointed to by "which python3" on the system, which is typically python3.4. To -# build for python3.5, set the environment variable to PYTHON3.5 -# -# -# Additionally, the script follows the directions of optional environment -# variables: -# TF_BUILD_DRY_RUN: If it is set to any non-empty value that is not "0", -# the script will just generate and print the final -# command, but not actually run it. -# TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS: -# String appended to the content of CI_DOCKER_EXTRA_PARAMS -# TF_BUILD_APPEND_ARGUMENTS: -# Additional command line arguments for the bazel, -# pip.sh or android.sh command -# TF_BUILD_MAVX: (Soon to be deprecated, use TF_BUILD_OPTIONS instead) -# (unset | MAVX | MAVX2) -# If set to MAVX or MAVX2, will cause bazel to use the -# additional flag --copt=-mavx or --copt=-mavx2, to -# perform AVX or AVX2 builds, respectively. This requires -# AVX- or AVX2-compatible CPUs. -# TF_BUILD_BAZEL_TARGET: -# Used to override the default bazel build target: -# //tensorflow/... -//tensorflow/compiler -# TF_BUILD_BAZEL_CLEAN: -# Will perform "bazel clean", if and only if this variable -# is set to any non-empty and non-0 value -# TF_BAZEL_BUILD_ONLY: -# If it is set to any non-empty value that is not "0", Bazel -# will only build specified targets -# TF_GPU_COUNT: -# Run this many parallel tests for serial builds. -# For now, only can be edited for PIP builds. -# TF_BUILD_TEST_TUTORIALS: -# If set to any non-empty and non-0 value, will perform -# tutorials tests (Applicable only if TF_BUILD_IS_PIP is -# PIP or BOTH). -# See builds/test_tutorials.sh -# TF_BUILD_INTEGRATION_TESTS: -# If set this will perform integration tests. See -# builds/integration_tests.sh. 
-# TF_BUILD_RUN_BENCHMARKS: -# If set to any non-empty and non-0 value, will perform -# the benchmark tests (see *_logged_benchmark targets in -# tools/test/BUILD) -# TF_BUILD_OPTIONS: -# (FASTBUILD | OPT | OPTDBG | MAVX | MAVX2_FMA | MAVX_DBG | -# MAVX2_FMA_DBG) -# Use the specified configurations when building. -# When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX -# options, as this will replace the two. -# TF_BUILD_TEST_TIMEOUT: -# Sets the value of bazel --test_timeout, defaults to -1 -# which uses the bazel defaults. -# TF_SKIP_CONTRIB_TESTS: -# If set to any non-empty or non-0 value, will skip running -# contrib tests. -# TF_NIGHTLY: -# If this run is being used to build the tf_nightly pip -# packages. -# TF_CUDA_CLANG: -# If set to 1, builds and runs cuda_clang configuration. -# Only available inside GPU containers. -# -# This script can be used by Jenkins parameterized / matrix builds. - -set -ex - -# Helper function: Convert to lower case -to_lower () { - echo "$1" | tr '[:upper:]' '[:lower:]' -} - -# Helper function: Strip leading and trailing whitespaces -str_strip () { - echo -e "$1" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -} - -# Helper function: Exit on failure -die () { - echo $@ - exit 1 -} - -########################################################## -# Default configuration -CI_BUILD_DIR="tensorflow/tools/ci_build" - -# Command to call when Docker is available -DOCKER_MAIN_CMD="${CI_BUILD_DIR}/ci_build.sh" -# Command to call when Docker is unavailable -NO_DOCKER_MAIN_CMD="${CI_BUILD_DIR}/builds/configured" - -# Additional option flags to apply when Docker is unavailable (e.g., on Mac) -NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone" - -DO_DOCKER=1 - -# Default values for various settings. 
-TF_BUILD_TEST_TIMEOUT=${TF_BUILD_TEST_TIMEOUT:--1} # Use bazel defaults -TF_GPU_COUNT=${TF_GPU_COUNT:-4} - -# Helpful flags: -# --test_summary=detailed: Tell us more about which targets are being built -# --keep_going: Don't stop at the first failure; tell us all the failures -# --build_tests_only: Don't build targets depended on by tests if the test is -# disabled. Also saves some compilation time. Otherwise, -# tries to build everything. -# --test_timeout: Test timeouts in the order short,moderate,long,eternal. -# --test_env: Environment variables to set when running bazel tests. These are -# especially important when using --run_under with -# parallel_gpu_execute. -BAZEL_TEST_FLAGS=""\ -"--test_summary=detailed --build_tests_only --keep_going "\ -"--test_timeout=${TF_BUILD_TEST_TIMEOUT} "\ -"--test_env=TF_GPU_COUNT=${TF_GPU_COUNT}" - -# Only set these environment variables if they're specified, to avoid causing -# problems like b/118404869, where an envvar set to the empty string has -# different semantics from an unset envvar. -if [ -n "${TF_TESTS_PER_GPU}" ]; then - BAZEL_TEST_FLAGS="${BAZEL_TEST_FLAGS} "\ -"--test_env=TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU}" -fi -if [ -n "${TF_PER_DEVICE_MEMORY_LIMIT_MB}" ]; then - BAZEL_TEST_FLAGS="${BAZEL_TEST_FLAGS} "\ -"--test_env=TF_PER_DEVICE_MEMORY_LIMIT_MB=${TF_PER_DEVICE_MEMORY_LIMIT_MB}" -fi - -BAZEL_BUILD_FLAGS="--keep_going" - -# Explicitly set jdk8 since that's what's installed in our images. Note that -# bazel 0.16 and higher defaults to jdk9, which causes failures. 
See b/117634064 -BAZEL_JAVA_FLAGS="--java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8" - -BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS} ${BAZEL_JAVA_FLAGS}" -BAZEL_BUILD_ONLY_CMD="bazel build ${BAZEL_BUILD_FLAGS} ${BAZEL_JAVA_FLAGS}" -BAZEL_CLEAN_CMD="bazel clean" - -PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" -PIP_TEST_TUTORIALS_FLAG="--test_tutorials" -PIP_INTEGRATION_TESTS_FLAG="--integration_tests" -ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh" -ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh" - -PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute' - -BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh" - -EXTRA_PARAMS="" -BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..." - -if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then - BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/..." -fi - -TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" - -########################################################## - -echo "Parameterized build starts at: $(date)" -echo "" -START_TIME=$(date +'%s') - -# Convert all the required environment variables to lower case -TF_BUILD_CONTAINER_TYPE=$(to_lower ${TF_BUILD_CONTAINER_TYPE}) -TF_BUILD_PYTHON_VERSION=$(to_lower ${TF_BUILD_PYTHON_VERSION}) -TF_BUILD_IS_OPT=$(to_lower ${TF_BUILD_IS_OPT}) -TF_BUILD_IS_PIP=$(to_lower ${TF_BUILD_IS_PIP}) - -if [[ ! 
-z "${TF_BUILD_MAVX}" ]]; then - TF_BUILD_MAVX=$(to_lower ${TF_BUILD_MAVX}) -fi - - -# Print parameter values -echo "Required build parameters:" -echo " TF_BUILD_CONTAINER_TYPE=${TF_BUILD_CONTAINER_TYPE}" -echo " TF_BUILD_PYTHON_VERSION=${TF_BUILD_PYTHON_VERSION}" -echo " TF_BUILD_IS_OPT=${TF_BUILD_IS_OPT}" -echo " TF_BUILD_IS_PIP=${TF_BUILD_IS_PIP}" -echo "Optional build parameters:" -echo " TF_BUILD_DRY_RUN=${TF_BUILD_DRY_RUN}" -echo " TF_BUILD_MAVX=${TF_BUILD_MAVX}" -echo " TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS="\ -"${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS}" -echo " TF_BUILD_APPEND_ARGUMENTS=${TF_BUILD_APPEND_ARGUMENTS}" -echo " TF_BUILD_BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET}" -echo " TF_BUILD_BAZEL_CLEAN=${TF_BUILD_BAZEL_CLEAN}" -echo " TF_BUILD_TEST_TUTORIALS=${TF_BUILD_TEST_TUTORIALS}" -echo " TF_BUILD_INTEGRATION_TESTS=${TF_BUILD_INTEGRATION_TESTS}" -echo " TF_BUILD_RUN_BENCHMARKS=${TF_BUILD_RUN_BENCHMARKS}" -echo " TF_BUILD_OPTIONS=${TF_BUILD_OPTIONS}" - - -# Function that tries to determine CUDA capability, if deviceQuery binary -# is available on path -function get_cuda_capability_version() { - if [[ ! -z $(which deviceQuery) ]]; then - # The first listed device is used - deviceQuery | grep "CUDA Capability .* version" | \ - head -1 | awk '{print $NF}' - fi -} - -# Container type, e.g., CPU, GPU -CTYPE=${TF_BUILD_CONTAINER_TYPE} - -# Determine if the machine is a Mac -OPT_FLAG="--test_output=errors" -if [[ "$(uname -s)" == "Darwin" ]]; then - DO_DOCKER=0 - - echo "It appears this machine is a Mac. "\ -"We will perform this build without Docker." - echo "Also, the additional option flags will be applied to the build:" - echo " ${NO_DOCKER_OPT_FLAG}" - MAIN_CMD="${NO_DOCKER_MAIN_CMD} ${CTYPE}" - OPT_FLAG="${OPT_FLAG} ${NO_DOCKER_OPT_FLAG}" -fi - -# In DO_DOCKER mode, appends environment variable to docker's run invocation. -# Otherwise, exports the corresponding variable. 
-function set_script_variable() { - local VAR="$1" - local VALUE="$2" - if [[ $DO_DOCKER == "1" ]]; then - TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS="${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e $VAR=$VALUE" - else - export $VAR="$VALUE" - fi -} - - -# Process container type -if [[ ${CTYPE} == cpu* ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then - : -elif [[ ${CTYPE} == gpu* ]]; then - set_script_variable TF_NEED_CUDA 1 - - if [[ $TF_CUDA_CLANG == "1" ]]; then - OPT_FLAG="${OPT_FLAG} --config=cuda_clang" - - set_script_variable TF_CUDA_CLANG 1 - # For cuda_clang we download `clang` while building. - set_script_variable TF_DOWNLOAD_CLANG 1 - else - OPT_FLAG="${OPT_FLAG} --config=cuda" - fi - - # Attempt to determine CUDA capability version automatically and use it if - # CUDA capability version is not specified by the environment variables. - CUDA_CAPA_VER=$(get_cuda_capability_version) - - if [[ ! -z ${CUDA_CAPA_VER} ]]; then - AUTO_CUDA_CAPA_VER=0 - if [[ ${DO_DOCKER} == "1" ]] && \ - [[ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS}" != \ - *"TF_CUDA_COMPUTE_CAPABILITIES="* ]]; then - AUTO_CUDA_CAPA_VER=1 - TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\ -"${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e "\ -"TF_CUDA_COMPUTE_CAPABILITIES=${CUDA_CAPA_VER}" - - echo "Docker GPU build: TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS="\ -"\"${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS}\"" - elif [[ ${DO_DOCKER} == "0" ]] && \ - [[ -z "${TF_CUDA_COMPUTE_CAPABILITIES}" ]]; then - AUTO_CUDA_CAPA_VER=1 - TF_CUDA_COMPUTE_CAPABILITIES="${CUDA_CAPA_VER}" - - echo "Non-Docker GPU build: TF_CUDA_COMPUTE_CAPABILITIES="\ -"\"${TF_CUDA_COMPUTE_CAPABILITIES}\"" - fi - - if [[ ${AUTO_CUDA_CAPA_VER} == "1" ]]; then - echo "TF_CUDA_COMPUTE_CAPABILITIES is not set:" - echo "Using CUDA capability version from deviceQuery: ${CUDA_CAPA_VER}" - echo "" - fi - fi -elif [[ ${CTYPE} == "android" ]] || [[ ${CTYPE} == "android_full" ]]; then - : -else - die "Unrecognized value in TF_BUILD_CONTAINER_TYPE: "\ 
-"\"${TF_BUILD_CONTAINER_TYPE}\"" -fi - -# Determine if this is a benchmarks job -RUN_BENCHMARKS=0 -if [[ ! -z "${TF_BUILD_RUN_BENCHMARKS}" ]] && - [[ "${TF_BUILD_RUN_BENCHMARKS}" != "0" ]]; then - RUN_BENCHMARKS=1 -fi - -# Process Bazel "-c opt" flag -if [[ -z "${TF_BUILD_OPTIONS}" ]]; then - if [[ ${TF_BUILD_IS_OPT} == "no_opt" ]]; then - # PIP builds are done only with the -c opt flag - if [[ ${TF_BUILD_IS_PIP} == "pip" ]]; then - echo "Skipping parameter combination: ${TF_BUILD_IS_OPT} & "\ -"${TF_BUILD_IS_PIP}" - exit 0 - fi - - elif [[ ${TF_BUILD_IS_OPT} == "opt" ]]; then - OPT_FLAG="${OPT_FLAG} -c opt" - else - die "Unrecognized value in TF_BUILD_IS_OPT: \"${TF_BUILD_IS_OPT}\"" - fi - - # Process MAVX option - if [[ ! -z "${TF_BUILD_MAVX}" ]]; then - if [[ "${TF_BUILD_MAVX}" == "mavx" ]]; then - OPT_FLAG="${OPT_FLAG} --copt=-mavx" - elif [[ "${TF_BUILD_MAVX}" == "mavx2" ]]; then - OPT_FLAG="${OPT_FLAG} --copt=-mavx2" - else - die "Unsupported value in TF_BUILD_MAVX: ${TF_BUILD_MAVX}" - fi - fi -else - case $TF_BUILD_OPTIONS in - FASTBUILD) - echo "Running FASTBUILD mode (noopt, nodbg)." - ;; - OPT) - OPT_FLAG="${OPT_FLAG} -c opt" - ;; - OPTDBG) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-g" - ;; - MAVX) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-mavx" - ;; - MAVX_DBG) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-g --copt=-mavx" - ;; - MAVX2_FMA) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-mavx2 --copt=-mfma" - ;; - MAVX2_FMA_DBG) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-g --copt=-mavx2 --copt=-mfma" - ;; - esac -fi - -# Strip whitespaces from OPT_FLAG -OPT_FLAG=$(str_strip "${OPT_FLAG}") - - -# 1) Filter out benchmark tests if this is not a benchmarks job; -# 2) Filter out tests with the "nomac" tag if the build is on Mac OS X. 
-EXTRA_ARGS=${DEFAULT_BAZEL_CONFIGS} -IS_MAC=0 -if [[ "$(uname)" == "Darwin" ]]; then - IS_MAC=1 -fi -if [[ "${TF_BUILD_APPEND_ARGUMENTS}" == *"--test_tag_filters="* ]]; then - ITEMS=(${TF_BUILD_APPEND_ARGUMENTS}) - - for ITEM in "${ITEMS[@]}"; do - if [[ ${ITEM} == *"--test_tag_filters="* ]]; then - NEW_ITEM="${ITEM}" - if [[ ${NEW_ITEM} != *"benchmark-test"* ]]; then - NEW_ITEM="${NEW_ITEM},-benchmark-test" - fi - if [[ ${IS_MAC} == "1" ]] && [[ ${NEW_ITEM} != *"nomac"* ]]; then - # TODO(b/122370901): Fix nomac, no_mac inconsistency. - NEW_ITEM="${NEW_ITEM},-nomac,-no_mac" - fi - EXTRA_ARGS="${EXTRA_ARGS} ${NEW_ITEM}" - else - EXTRA_ARGS="${EXTRA_ARGS} ${ITEM}" - fi - done -else - EXTRA_ARGS="${EXTRA_ARGS} ${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-no_oss,-oss_serial,-benchmark-test" - if [[ ${IS_MAC} == "1" ]]; then - # TODO(b/122370901): Fix nomac, no_mac inconsistency. - EXTRA_ARGS="${EXTRA_ARGS},-nomac,-no_mac" - fi - EXTRA_ARGS="${EXTRA_ARGS} --build_tag_filters=-no_oss,-oss_serial,-benchmark-test" - if [[ ${IS_MAC} == "1" ]]; then - # TODO(b/122370901): Fix nomac, no_mac inconsistency. - EXTRA_ARGS="${EXTRA_ARGS},-nomac,-no_mac" - fi -fi - -# For any "tool" dependencies in genrules, Bazel will build them for host -# instead of the target configuration. We can save some build time by setting -# this flag, and it only affects a few tests. -EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false" - -if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] && - [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then - BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD} -fi - -# Process PIP install-test option -if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || - [[ ${TF_BUILD_IS_PIP} == "both" ]]; then - # Process optional bazel target override - if [[ ! -z "${TF_BUILD_BAZEL_TARGET}" ]]; then - BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET} - fi - - if [[ ${CTYPE} == cpu* ]] || \ - [[ ${CTYPE} == "debian.jessie.cpu" ]]; then - # CPU only command, fully parallel. 
- NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} "\ -"${EXTRA_ARGS} -- ${BAZEL_TARGET}" - elif [[ ${CTYPE} == gpu* ]]; then - # GPU only command, run as many jobs as the GPU count only. - NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\ -"--local_test_jobs=${TF_GPU_COUNT} "\ -"--run_under=${PARALLEL_GPU_TEST_CMD} "\ -"${EXTRA_ARGS} -- ${BAZEL_TARGET}" - elif [[ ${CTYPE} == "android" ]]; then - # Run android specific script for android build. - NO_PIP_MAIN_CMD="${ANDROID_CMD} ${OPT_FLAG} " - elif [[ ${CTYPE} == "android_full" ]]; then - # Run android specific script for full android build. - NO_PIP_MAIN_CMD="${ANDROID_FULL_CMD} ${OPT_FLAG} " - fi - -fi - -if [[ ${TF_BUILD_IS_PIP} == "pip" ]] || - [[ ${TF_BUILD_IS_PIP} == "both" ]]; then - # Android builds conflict with PIP builds - if [[ ${CTYPE} == "android" ]]; then - echo "Skipping parameter combination: ${TF_BUILD_IS_PIP} & "\ -"${TF_BUILD_CONTAINER_TYPE}" - exit 0 - fi - - PIP_MAIN_CMD="${MAIN_CMD} ${PIP_CMD} ${CTYPE} ${EXTRA_ARGS} ${OPT_FLAG}" - - # Add flag for integration tests - if [[ ! -z "${TF_BUILD_INTEGRATION_TESTS}" ]] && - [[ "${TF_BUILD_INTEGRATION_TESTS}" != "0" ]]; then - PIP_MAIN_CMD="${PIP_MAIN_CMD} ${PIP_INTEGRATION_TESTS_FLAG}" - fi - - # Add command for tutorial test - if [[ ! 
-z "${TF_BUILD_TEST_TUTORIALS}" ]] && - [[ "${TF_BUILD_TEST_TUTORIALS}" != "0" ]]; then - PIP_MAIN_CMD="${PIP_MAIN_CMD} ${PIP_TEST_TUTORIALS_FLAG}" - - # Prepare data directory for tutorial tests - mkdir -p "${TUT_TEST_DATA_DIR}" || - die "FAILED to create data directory for tutorial tests: "\ - "${TUT_TEST_DATA_DIR}" - - if [[ "${DO_DOCKER}" == "1" ]]; then - EXTRA_PARAMS="${EXTRA_PARAMS} -v ${TUT_TEST_DATA_DIR}:${TUT_TEST_DATA_DIR}" - fi - fi -fi - - -if [[ ${RUN_BENCHMARKS} == "1" ]]; then - MAIN_CMD="${BENCHMARK_CMD} ${OPT_FLAG}" -elif [[ ${TF_BUILD_IS_PIP} == "no_pip" ]]; then - MAIN_CMD="${NO_PIP_MAIN_CMD}" -elif [[ ${TF_BUILD_IS_PIP} == "pip" ]]; then - MAIN_CMD="${PIP_MAIN_CMD}" -elif [[ ${TF_BUILD_IS_PIP} == "both" ]]; then - MAIN_CMD="${NO_PIP_MAIN_CMD} && ${PIP_MAIN_CMD}" -else - die "Unrecognized value in TF_BUILD_IS_PIP: \"${TF_BUILD_IS_PIP}\"" -fi - -# Check if this is a tf_nightly build -if [[ "${TF_NIGHTLY}" == "1" ]]; then - EXTRA_PARAMS="${EXTRA_PARAMS} -e TF_NIGHTLY=1" -fi - -# Process Python version -if [[ ${TF_BUILD_PYTHON_VERSION} == "python2" ]]; then - : -elif [[ ${TF_BUILD_PYTHON_VERSION} == "python3" || \ - ${TF_BUILD_PYTHON_VERSION} == "python3.4" || \ - ${TF_BUILD_PYTHON_VERSION} == "python3.5" || \ - ${TF_BUILD_PYTHON_VERSION} == "python3.6" ]]; then - # Supply proper environment variable to select Python 3 - if [[ "${DO_DOCKER}" == "1" ]]; then - EXTRA_PARAMS="${EXTRA_PARAMS} -e CI_BUILD_PYTHON=${TF_BUILD_PYTHON_VERSION}" - else - # Determine the path to python3 - PYTHON3_PATH=$(which "${TF_BUILD_PYTHON_VERSION}" | head -1) - if [[ -z "${PYTHON3_PATH}" ]]; then - die "ERROR: Failed to locate ${TF_BUILD_PYTHON_VERSION} binary on path" - else - echo "Found ${TF_BUILD_PYTHON_VERSION} binary at: ${PYTHON3_PATH}" - fi - - export PYTHON_BIN_PATH="${PYTHON3_PATH}" - fi - -else - die "Unrecognized value in TF_BUILD_PYTHON_VERSION: "\ -"\"${TF_BUILD_PYTHON_VERSION}\"" -fi - -# Append additional Docker extra parameters 
-EXTRA_PARAMS="${EXTRA_PARAMS} ${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS}" - -# Finally, do a dry run or call the command - -# The command, which may consist of multiple parts (e.g., in the case of -# TF_BUILD_SERIAL_TESTS=1), are written to a bash script, which is -# then called. The name of the script is randomized to make concurrent -# builds on the node possible. -TMP_SCRIPT="$(mktemp)_ci_parameterized_build.sh" - -if [[ "${DO_DOCKER}" == "1" ]]; then - # Map the tmp script into the Docker container - EXTRA_PARAMS="${EXTRA_PARAMS} -v ${TMP_SCRIPT}:/tmp/tf_build.sh" - - if [[ ! -z "${TF_BUILD_BAZEL_CLEAN}" ]] && - [[ "${TF_BUILD_BAZEL_CLEAN}" != "0" ]] && - [[ "${TF_BUILD_IS_PIP}" != "both" ]]; then - # For TF_BUILD_IS_PIP == both, "bazel clean" will have already - # been performed before the "bazel test" step - EXTRA_PARAMS="${EXTRA_PARAMS} -e TF_BUILD_BAZEL_CLEAN=1" - fi - - EXTRA_PARAMS=$(str_strip "${EXTRA_PARAMS}") - - echo "Exporting CI_DOCKER_EXTRA_PARAMS: ${EXTRA_PARAMS}" - export CI_DOCKER_EXTRA_PARAMS="${EXTRA_PARAMS}" -fi - -# Write to the tmp script -echo "#!/usr/bin/env bash" > ${TMP_SCRIPT} -if [[ ! -z "${TF_BUILD_BAZEL_CLEAN}" ]] && - [[ "${TF_BUILD_BAZEL_CLEAN}" != "0" ]]; then - echo ${BAZEL_CLEAN_CMD} >> ${TMP_SCRIPT} -fi -echo ${MAIN_CMD} >> ${TMP_SCRIPT} - -echo "Executing final command (${TMP_SCRIPT})..." 
-echo "==========================================" -cat ${TMP_SCRIPT} -echo "==========================================" -echo "" - - -TMP_DIR="" -DOCKERFILE_FLAG="" -if [[ "${DO_DOCKER}" == "1" ]]; then - if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]] || - [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then - # Modify Dockerfile for Python3.5 | Python3.6 build - TMP_DIR=$(mktemp -d) - echo "Docker build will occur in temporary directory: ${TMP_DIR}" - - # Copy the files required for the docker build - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - cp -r "${SCRIPT_DIR}/install" "${TMP_DIR}/install" || \ - die "ERROR: Failed to copy directory ${SCRIPT_DIR}/install" - - DOCKERFILE="${SCRIPT_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" - cp "${DOCKERFILE}" "${TMP_DIR}/" || \ - die "ERROR: Failed to copy Dockerfile at ${DOCKERFILE}" - DOCKERFILE="${TMP_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" - - # Replace a line in the Dockerfile - if sed -i \ - "s/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_${TF_BUILD_PYTHON_VERSION}_pip_packages.sh/g" \ - "${DOCKERFILE}" - then - echo "Copied and modified Dockerfile for ${TF_BUILD_PYTHON_VERSION} build: ${DOCKERFILE}" - else - die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" - fi - - DOCKERFILE_FLAG="--dockerfile ${DOCKERFILE}" - fi -fi - -# Set a disk usage trap. -function debug_disk_usage { - echo "Finished script... disk usage report in ${TMP_DIR}" - du -k -d 2 ${TMP_DIR} | sort -n -r -} -# trap debug_disk_usage EXIT - -chmod +x ${TMP_SCRIPT} - -# Map TF_BUILD container types to containers we actually have. -if [[ "${CTYPE}" == "android_full" ]]; then - CONTAINER="android" -else - CONTAINER=${CTYPE} -fi - -FAILURE=0 -if [[ ! 
-z "${TF_BUILD_DRY_RUN}" ]] && [[ ${TF_BUILD_DRY_RUN} != "0" ]]; then - # Do a dry run: just print the final command - echo "*** This is a DRY RUN ***" -else - # Actually run the command - if [[ "${DO_DOCKER}" == "1" ]]; then - ${DOCKER_MAIN_CMD} ${CONTAINER} ${DOCKERFILE_FLAG} /tmp/tf_build.sh - else - ${TMP_SCRIPT} - fi - - if [[ $? != "0" ]]; then - FAILURE=1 - fi -fi - -[[ ${FAILURE} == "0" ]] && RESULT="SUCCESS" || RESULT="FAILURE" - -rm -f ${TMP_SCRIPT} - -END_TIME=$(date +'%s') -echo "" -echo "Parameterized build ends with ${RESULT} at: $(date) "\ -"(Elapsed time: $((END_TIME - START_TIME)) s)" - -# Dump disk usage -debug_disk_usage - -# Clean up temporary directory if it exists -if [[ ! -z "${TMP_DIR}" ]]; then - echo "Cleaning up temporary directory: ${TMP_DIR}" - rm -rf "${TMP_DIR}" -fi - -exit ${FAILURE} diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index e152f9b6a22..c8df1e34bc0 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -205,12 +205,10 @@ filegroup( "@opt_einsum_archive//:LICENSE", "@org_python_pypi_backports_weakref//:LICENSE", "@pasta//:LICENSE", - "@pcre//:LICENCE", "@png//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", "@sobol_data//:LICENSE", - "@swig//:LICENSE", "@termcolor_archive//:COPYING.txt", "@zlib//:zlib.h", ] + select({ diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 082cba67f75..f4fb80d5caa 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -562,6 +562,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"), + system_link_files = { + "//third_party/systemlibs:BUILD": "bazel/BUILD", + "//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD", + 
"//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl", + }, urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz", "https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz", @@ -591,8 +596,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "c8de17bca658e62bbf8c33eae839e457332e885e" - LLVM_SHA256 = "a1a4b06037c7b19a5f9414fee9626252e4de3e9d9461c8095cc569ee25d647a3" + LLVM_COMMIT = "c6e917d2d3ea07960721923230c34abe3b6214cc" + LLVM_SHA256 = "2c9e67fb2638dc9920b26422b2310b2dd0d203ada4fe20d2b300e8d6b453c5f3" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), @@ -1034,6 +1039,7 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", strip_prefix = "pybind11-2.4.3", build_file = clean_dep("//third_party:pybind11.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:pybind11.BUILD"), ) tf_http_archive( diff --git a/third_party/jpeg/BUILD.bazel b/third_party/jpeg/BUILD.bazel index 90e45237c7d..269e5254c86 100644 --- a/third_party/jpeg/BUILD.bazel +++ b/third_party/jpeg/BUILD.bazel @@ -516,30 +516,30 @@ JCONFIG_NOWIN_COMMON_SUBSTITUTIONS = { "@JPEG_LIB_VERSION@": "62", "@VERSION@": "2.0.0", "@LIBJPEG_TURBO_VERSION_NUMBER@": "2000000", - "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED", - "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED", - "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED", + "#cmakedefine C_ARITH_CODING_SUPPORTED 1": "#define C_ARITH_CODING_SUPPORTED 1", + "#cmakedefine D_ARITH_CODING_SUPPORTED 1": 
"#define D_ARITH_CODING_SUPPORTED 1", + "#cmakedefine MEM_SRCDST_SUPPORTED 1": "#define MEM_SRCDST_SUPPORTED 1", "@BITS_IN_JSAMPLE@": "8", - "#cmakedefine HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1", - "#cmakedefine HAVE_STDDEF_H": "#define HAVE_STDDEF_H 1", - "#cmakedefine HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1", - "#cmakedefine NEED_SYS_TYPES_H": "#define NEED_SYS_TYPES_H", - "#cmakedefine NEED_BSD_STRINGS": "", - "#cmakedefine HAVE_UNSIGNED_CHAR": "#define HAVE_UNSIGNED_CHAR 1", - "#cmakedefine HAVE_UNSIGNED_SHORT": "#define HAVE_UNSIGNED_SHORT 1", - "#cmakedefine INCOMPLETE_TYPES_BROKEN": "", - "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED": "", - "#cmakedefine __CHAR_UNSIGNED__": "", + "#cmakedefine HAVE_LOCALE_H 1": "#define HAVE_LOCALE_H 1", + "#cmakedefine HAVE_STDDEF_H 1": "#define HAVE_STDDEF_H 1", + "#cmakedefine HAVE_STDLIB_H 1": "#define HAVE_STDLIB_H 1", + "#cmakedefine NEED_SYS_TYPES_H 1": "#define NEED_SYS_TYPES_H 1", + "#cmakedefine NEED_BSD_STRINGS 1": "", + "#cmakedefine HAVE_UNSIGNED_CHAR 1": "#define HAVE_UNSIGNED_CHAR 1", + "#cmakedefine HAVE_UNSIGNED_SHORT 1": "#define HAVE_UNSIGNED_SHORT 1", + "#cmakedefine INCOMPLETE_TYPES_BROKEN 1": "", + "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED 1": "", + "#cmakedefine __CHAR_UNSIGNED__ 1": "", "#undef const": "", "#undef size_t": "", } JCONFIG_NOWIN_SIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD": "#define WITH_SIMD", + "#cmakedefine WITH_SIMD 1": "#define WITH_SIMD 1", } JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD": "", + "#cmakedefine WITH_SIMD 1": "", } JCONFIG_NOWIN_SIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) diff --git a/third_party/jpeg/workspace.bzl b/third_party/jpeg/workspace.bzl index e2137ba949f..c458ff12ba8 100644 --- a/third_party/jpeg/workspace.bzl +++ b/third_party/jpeg/workspace.bzl @@ -6,11 +6,11 @@ def repo(): third_party_http_archive( name = "libjpeg_turbo", urls = [ - 
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz", - "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz", + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz", ], - sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b", - strip_prefix = "libjpeg-turbo-2.0.0", + sha256 = "7777c3c19762940cff42b3ba4d7cd5c52d1671b39a79532050c85efb99079064", + strip_prefix = "libjpeg-turbo-2.0.4", build_file = "//third_party/jpeg:BUILD.bazel", system_build_file = "//third_party/jpeg:BUILD.system", ) diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 24b2774ba7a..a9199ba9505 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -250,7 +250,7 @@ cc_library( name = "AVX512ToLLVM", srcs = glob([ "lib/Conversion/AVX512ToLLVM/*.cpp", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/AVX512ToLLVM/*.h", ]), @@ -326,7 +326,10 @@ gentbl( cc_library( name = "LoopOpsTransforms", - srcs = glob(["lib/Dialect/LoopOps/Transforms/*.cpp"]), + srcs = glob([ + "lib/Dialect/LoopOps/Transforms/*.cpp", + "lib/Dialect/LoopOps/Transforms/*.h", + ]), hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"], includes = ["include"], deps = [ @@ -339,7 +342,6 @@ cc_library( ":Transforms", "@llvm-project//llvm:support", ], - alwayslink = 1, ) filegroup( @@ -467,6 +469,7 @@ cc_library( name = "AffineTransforms", srcs = glob([ "lib/Dialect/Affine/Transforms/*.cpp", + "lib/Dialect/Affine/Transforms/*.h", ]), hdrs = [ "include/mlir/Dialect/Affine/Passes.h", @@ -508,7 +511,7 @@ cc_library( srcs = glob([ "lib/Conversion/AffineToStandard/*.cpp", "lib/Conversion/AffineToStandard/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob(["include/mlir/Conversion/AffineToStandard/*.h"]), includes = ["include"], deps = [ @@ 
-846,7 +849,6 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", ], - alwayslink = 1, ) cc_library( @@ -906,7 +908,10 @@ gentbl( cc_library( name = "LLVMIRTransforms", - srcs = glob(["lib/Dialect/LLVMIR/Transforms/*.cpp"]), + srcs = glob([ + "lib/Dialect/LLVMIR/Transforms/*.cpp", + "lib/Dialect/LLVMIR/Transforms/*.h", + ]), hdrs = glob(["include/mlir/Dialect/LLVMIR/Transforms/*.h"]), includes = ["include"], deps = [ @@ -1098,7 +1103,7 @@ cc_library( srcs = glob([ "lib/Conversion/GPUToNVVM/*.cpp", "lib/Conversion/GPUToNVVM/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/GPUToNVVM/*.h", ]), @@ -1120,7 +1125,10 @@ cc_library( cc_library( name = "GPUToROCDLTransforms", - srcs = ["lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp"], + srcs = [ + "lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp", + "lib/Conversion/PassDetail.h", + ], hdrs = [ "include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h", ], @@ -1141,6 +1149,7 @@ cc_library( srcs = [ "lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp", "lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp", + "lib/Conversion/PassDetail.h", ], hdrs = ["include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"], includes = ["include"], @@ -1163,6 +1172,7 @@ cc_library( srcs = [ "lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp", "lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp", + "lib/Conversion/PassDetail.h", ], hdrs = ["include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"], includes = ["include"], @@ -1203,6 +1213,7 @@ cc_library( srcs = [ "lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp", "lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp", + "lib/Conversion/PassDetail.h", ], hdrs = [ "include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h", @@ -1640,11 +1651,11 @@ gentbl( cc_library( name = "SPIRVLowering", - srcs = [ + srcs = glob([ + "lib/Dialect/SPIRV/Transforms/*.cpp", + 
"lib/Dialect/SPIRV/Transforms/*.h", + ]) + [ "lib/Dialect/SPIRV/SPIRVLowering.cpp", - "lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp", - "lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp", - "lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp", ], hdrs = [ "include/mlir/Dialect/SPIRV/Passes.h", @@ -1672,7 +1683,7 @@ cc_library( srcs = glob([ "lib/Conversion/StandardToSPIRV/*.cpp", "lib/Conversion/StandardToSPIRV/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/StandardToSPIRV/*.h", ]), @@ -1757,7 +1768,6 @@ cc_library( ":Translation", "@llvm-project//llvm:support", ], - alwayslink = 1, ) cc_library( @@ -1924,6 +1934,7 @@ cc_library( name = "LoopsToGPUPass", srcs = [ "lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp", + "lib/Conversion/PassDetail.h", ], hdrs = [ "include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h", @@ -1947,6 +1958,7 @@ cc_library( name = "CFGTransforms", srcs = [ "lib/Conversion/LoopToStandard/LoopToStandard.cpp", + "lib/Conversion/PassDetail.h", ], hdrs = [ "include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h", @@ -1968,6 +1980,7 @@ cc_library( cc_library( name = "LLVMTransforms", srcs = [ + "lib/Conversion/PassDetail.h", "lib/Conversion/StandardToLLVM/StandardToLLVM.cpp", ], hdrs = [ @@ -2235,7 +2248,6 @@ cc_library( "@llvm-project//llvm:ir_reader", "@llvm-project//llvm:support", ], - alwayslink = 1, ) cc_library( @@ -2259,7 +2271,6 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", ], - alwayslink = 1, ) cc_library( @@ -2283,7 +2294,6 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", ], - alwayslink = 1, ) # TODO(zinenko): Update these so that we can simplify mapping to cmake. 
@@ -2372,11 +2382,23 @@ cc_library( ], ) +cc_library( + name = "AllTranslations", + hdrs = ["include/mlir/InitAllTranslations.h"], + deps = [ + ":SPIRVTranslateRegistration", + ":TargetLLVMIR", + ":TargetNVVMIR", + ":TargetROCDLIR", + ], +) + cc_library( name = "MlirTranslateMain", srcs = ["tools/mlir-translate/mlir-translate.cpp"], deps = [ ":AllPassesAndDialectsNoRegistration", + ":AllTranslations", ":IR", ":Parser", ":Support", @@ -2389,10 +2411,6 @@ cc_binary( name = "mlir-translate", deps = [ ":MlirTranslateMain", - ":SPIRVTranslateRegistration", - ":TargetLLVMIR", - ":TargetNVVMIR", - ":TargetROCDLIR", ], ) @@ -2760,6 +2778,7 @@ cc_library( "lib/Dialect/Quant/IR/TypeParser.cpp", "lib/Dialect/Quant/Transforms/ConvertConst.cpp", "lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp", + "lib/Dialect/Quant/Transforms/PassDetail.h", "lib/Dialect/Quant/Utils/FakeQuantSupport.cpp", "lib/Dialect/Quant/Utils/QuantizeUtils.cpp", "lib/Dialect/Quant/Utils/UniformSupport.cpp", @@ -2913,7 +2932,7 @@ cc_library( srcs = glob([ "lib/Conversion/LinalgToLLVM/*.cpp", "lib/Conversion/LinalgToLLVM/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/LinalgToLLVM/*.h", ]), @@ -2944,7 +2963,7 @@ cc_library( srcs = glob([ "lib/Conversion/LinalgToSPIRV/*.cpp", "lib/Conversion/LinalgToSPIRV/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/LinalgToSPIRV/*.h", ]), @@ -3005,14 +3024,12 @@ gentbl( cc_library( name = "LinalgTransforms", - srcs = [ + srcs = glob([ + "lib/Dialect/Linalg/Transforms/*.cpp", + "lib/Dialect/Linalg/Transforms/*.h", + ]) + [ "lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp", "lib/Dialect/Linalg/EDSC/Builders.cpp", - "lib/Dialect/Linalg/Transforms/Fusion.cpp", - "lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp", - "lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp", - "lib/Dialect/Linalg/Transforms/Promotion.cpp", - "lib/Dialect/Linalg/Transforms/Tiling.cpp", 
"lib/Dialect/Linalg/Utils/Utils.cpp", ], hdrs = [ @@ -3120,7 +3137,7 @@ cc_library( srcs = glob([ "lib/Conversion/VectorToLLVM/*.cpp", "lib/Conversion/VectorToLLVM/*.h", - ]), + ]) + ["lib/Conversion/PassDetail.h"], hdrs = glob([ "include/mlir/Conversion/VectorToLLVM/*.h", ]), diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index 476fd8b77df..f242ae76287 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -46,6 +46,22 @@ gentbl( ], ) +gentbl( + name = "TestLinalgMatmulToVectorPatternsIncGen", + tbl_outs = [ + ( + "-gen-rewriters", + "lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td", + td_srcs = [ + "@llvm-project//mlir:VectorTransformPatternsTdFiles", + "@llvm-project//mlir:LinalgTransformPatternsTdFiles", + ], +) + gentbl( name = "TestOpsIncGen", strip_include_prefix = "lib/Dialect/Test", @@ -157,6 +173,7 @@ cc_library( includes = ["lib/Dialect/Test"], deps = [ ":TestDialect", + ":TestLinalgMatmulToVectorPatternsIncGen", ":TestLinalgTransformPatternsIncGen", ":TestVectorTransformPatternsIncGen", "@llvm-project//llvm:support", diff --git a/third_party/pybind11.BUILD b/third_party/pybind11.BUILD index 95f452c05af..2f1ada6193c 100644 --- a/third_party/pybind11.BUILD +++ b/third_party/pybind11.BUILD @@ -18,6 +18,7 @@ cc_library( "-Wno-pragma-once-outside-header", ], includes = ["include"], + strip_include_prefix = "include", deps = [ "@org_tensorflow//third_party/python_runtime:headers", ], diff --git a/third_party/systemlibs/astunparse.BUILD b/third_party/systemlibs/astunparse.BUILD new file mode 100644 index 00000000000..e398ff3a024 --- /dev/null +++ b/third_party/systemlibs/astunparse.BUILD @@ -0,0 +1,14 @@ +# Description: +# AST round-trip manipulation for Python. 
+ +licenses(["notice"]) + +py_library( + name = "astunparse", + visibility = ["//visibility:public"], +) + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/grpc.bazel.grpc_deps.bzl b/third_party/systemlibs/grpc.bazel.grpc_deps.bzl new file mode 100644 index 00000000000..dd389c68550 --- /dev/null +++ b/third_party/systemlibs/grpc.bazel.grpc_deps.bzl @@ -0,0 +1,6 @@ +"""Load dependencies needed to compile and test the grpc library as a 3rd-party consumer.""" + +def grpc_deps(): + """Loads dependencies need to compile and test the grpc library.""" + + pass diff --git a/third_party/systemlibs/protobuf.bzl b/third_party/systemlibs/protobuf.bzl index bb807e904a3..367ac286395 100644 --- a/third_party/systemlibs/protobuf.bzl +++ b/third_party/systemlibs/protobuf.bzl @@ -93,6 +93,7 @@ def _proto_gen_impl(ctx): args += ["--python_out=" + gen_dir] inputs = srcs + deps + tools = [ctx.executable.protoc] if ctx.executable.plugin: plugin = ctx.executable.plugin lang = ctx.attr.plugin_language @@ -106,7 +107,7 @@ def _proto_gen_impl(ctx): outdir = ",".join(ctx.attr.plugin_options) + ":" + outdir args += ["--plugin=protoc-gen-%s=%s" % (lang, plugin.path)] args += ["--%s_out=%s" % (lang, outdir)] - inputs += [plugin] + tools.append(plugin) if args: ctx.actions.run( @@ -115,6 +116,7 @@ def _proto_gen_impl(ctx): arguments = args + import_flags + [s.path for s in srcs], executable = ctx.executable.protoc, mnemonic = "ProtoCompile", + tools = tools, use_default_shell_env = True, ) diff --git a/third_party/systemlibs/pybind11.BUILD b/third_party/systemlibs/pybind11.BUILD new file mode 100644 index 00000000000..79a483d7b5d --- /dev/null +++ b/third_party/systemlibs/pybind11.BUILD @@ -0,0 +1,8 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "pybind11", + deps = [ + "@org_tensorflow//third_party/python_runtime:headers", + ], +) diff --git a/third_party/systemlibs/syslibs_configure.bzl 
b/third_party/systemlibs/syslibs_configure.bzl index 7a96fdf9d21..217c0131186 100644 --- a/third_party/systemlibs/syslibs_configure.bzl +++ b/third_party/systemlibs/syslibs_configure.bzl @@ -11,6 +11,7 @@ _TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS" VALID_LIBS = [ "absl_py", "astor_archive", + "astunparse_archive", "boringssl", "com_github_googleapis_googleapis", "com_github_googlecloudplatform_google_cloud_cpp", @@ -37,6 +38,7 @@ VALID_LIBS = [ "pasta", "pcre", "png", + "pybind11", "six_archive", "snappy", "swig",